This commit is contained in:
Patrick 2024-11-06 22:55:07 +01:00
parent 6688a48fb0
commit cba77c00ba
23 changed files with 2309 additions and 0 deletions

24
CMakeLists.txt Normal file
View File

@ -0,0 +1,24 @@
cmake_minimum_required(VERSION 3.20)
project(OAuth2Client C)
set(CMAKE_C_STANDARD 23)
set(CMAKE_BUILD_TYPE "DEBUG")
set(CMAKE_EXPORT_COMPILE_COMMANDS on)
find_package(OpenSSL REQUIRED)
add_library(OAuth2Lib
util/thread_queue.c
util/sorted_str_set.c
networking.c
pkce.c
base64.c
ssl.c
log.c
server.c)
target_include_directories(OAuth2Lib PRIVATE ${OpenSSL_INCLUDE_DIRS})
target_link_libraries(OAuth2Lib PRIVATE OpenSSL::SSL OpenSSL::Crypto)
add_executable(main main.c)
target_link_libraries(main OAuth2Lib)

286
base64.c Normal file
View File

@ -0,0 +1,286 @@
#include "base64.h"
#include <memory.h>
#include <stdlib.h>
static const char base64url_table[65] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
// http://web.mit.edu/freebsd/head/contrib/wpa/src/utils/base64.c
/*
* Base64 encoding/decoding (RFC1341)
* Copyright (c) 2005-2011, Jouni Malinen <j@w1.fi>
*
* This software may be distributed under the terms of the BSD license.
* See README for more details.
*/
static const unsigned char base64_table[65] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/**
* base64_encode - Base64 encode
* @src: Data to be encoded
* @len: Length of the data to be encoded
* @out_len: Pointer to output length variable, or %NULL if not used
* Returns: Allocated buffer of out_len bytes of encoded data,
* or %NULL on failure
*
* Caller is responsible for freeing the returned buffer. Returned buffer is
* nul terminated to make it easier to use as a C string. The nul terminator is
* not included in out_len.
*/
char * base64_encode(const uint8_t *src, size_t len,
size_t *out_len)
{
unsigned char *out, *pos;
const unsigned char *end, *in;
size_t olen;
int line_len;
olen = len * 4 / 3 + 4; /* 3-byte blocks to 4-byte */
olen += olen / 72; /* line feeds */
olen++; /* nul termination */
if (olen < len)
return NULL; /* integer overflow */
out = malloc(olen);
if (out == NULL)
return NULL;
end = src + len;
in = src;
pos = out;
line_len = 0;
while (end - in >= 3) {
*pos++ = base64_table[in[0] >> 2];
*pos++ = base64_table[((in[0] & 0x03) << 4) | (in[1] >> 4)];
*pos++ = base64_table[((in[1] & 0x0f) << 2) | (in[2] >> 6)];
*pos++ = base64_table[in[2] & 0x3f];
in += 3;
line_len += 4;
if (line_len >= 72) {
*pos++ = '\n';
line_len = 0;
}
}
if (end - in) {
*pos++ = base64_table[in[0] >> 2];
if (end - in == 1) {
*pos++ = base64_table[(in[0] & 0x03) << 4];
*pos++ = '=';
} else {
*pos++ = base64_table[((in[0] & 0x03) << 4) |
(in[1] >> 4)];
*pos++ = base64_table[(in[1] & 0x0f) << 2];
}
*pos++ = '=';
line_len += 4;
}
if (line_len)
*pos++ = '\n';
*pos = '\0';
if (out_len)
*out_len = pos - out;
return out;
}
/**
* base64_decode - Base64 decode
* @src: Data to be decoded
* @len: Length of the data to be decoded
* @out_len: Pointer to output length variable
* Returns: Allocated buffer of out_len bytes of decoded data,
* or %NULL on failure
*
* Caller is responsible for freeing the returned buffer.
*/
uint8_t* base64_decode(const char *src, size_t len,
size_t *out_len)
{
unsigned char dtable[256], *out, *pos, block[4], tmp;
size_t i, count, olen;
int pad = 0;
memset(dtable, 0x80, 256);
for (i = 0; i < sizeof(base64_table) - 1; i++)
dtable[base64_table[i]] = (unsigned char) i;
dtable['='] = 0;
count = 0;
for (i = 0; i < len; i++) {
if (dtable[src[i]] != 0x80)
count++;
}
if (count == 0 || count % 4)
return NULL;
olen = count / 4 * 3;
pos = out = malloc(olen);
if (out == NULL)
return NULL;
count = 0;
for (i = 0; i < len; i++) {
tmp = dtable[src[i]];
if (tmp == 0x80)
continue;
if (src[i] == '=')
pad++;
block[count] = tmp;
count++;
if (count == 4) {
*pos++ = (block[0] << 2) | (block[1] >> 4);
*pos++ = (block[1] << 4) | (block[2] >> 2);
*pos++ = (block[2] << 6) | block[3];
count = 0;
if (pad) {
if (pad == 1)
pos--;
else if (pad == 2)
pos -= 2;
else {
/* Invalid padding */
free(out);
return NULL;
}
break;
}
}
}
*out_len = pos - out;
return out;
}
char * base64url_encode(const uint8_t *src, size_t len,
size_t *out_len)
{
unsigned char *out, *pos;
const unsigned char *end, *in;
size_t olen;
int line_len;
olen = len * 4 / 3 + 4; /* 3-byte blocks to 4-byte */
olen += olen / 72; /* line feeds */
olen++; /* nul termination */
if (olen < len)
return NULL; /* integer overflow */
out = malloc(olen);
if (out == NULL)
return NULL;
end = src + len;
in = src;
pos = out;
line_len = 0;
while (end - in >= 3) {
*pos++ = base64url_table[in[0] >> 2];
*pos++ = base64url_table[((in[0] & 0x03) << 4) | (in[1] >> 4)];
*pos++ = base64url_table[((in[1] & 0x0f) << 2) | (in[2] >> 6)];
*pos++ = base64url_table[in[2] & 0x3f];
in += 3;
line_len += 4;
if (line_len >= 72) {
*pos++ = '\n';
line_len = 0;
}
}
if (end - in) {
*pos++ = base64url_table[in[0] >> 2];
if (end - in == 1) {
*pos++ = base64url_table[(in[0] & 0x03) << 4];
*pos++ = '=';
} else {
*pos++ = base64url_table[((in[0] & 0x03) << 4) |
(in[1] >> 4)];
*pos++ = base64url_table[(in[1] & 0x0f) << 2];
}
*pos++ = '=';
line_len += 4;
}
if (line_len)
*pos++ = '\n';
*pos = '\0';
if (out_len)
*out_len = pos - out;
return out;
}
/**
* base64_decode - Base64 decode
* @src: Data to be decoded
* @len: Length of the data to be decoded
* @out_len: Pointer to output length variable
* Returns: Allocated buffer of out_len bytes of decoded data,
* or %NULL on failure
*
* Caller is responsible for freeing the returned buffer.
*/
uint8_t* base64url_decode(const char *src, size_t len,
size_t *out_len)
{
unsigned char dtable[256], *out, *pos, block[4], tmp;
size_t i, count, olen;
int pad = 0;
memset(dtable, 0x80, 256);
for (i = 0; i < sizeof(base64url_table) - 1; i++)
dtable[base64url_table[i]] = (unsigned char) i;
dtable['='] = 0;
count = 0;
for (i = 0; i < len; i++) {
if (dtable[src[i]] != 0x80)
count++;
}
if (count == 0 || count % 4)
return NULL;
olen = count / 4 * 3;
pos = out = malloc(olen);
if (out == NULL)
return NULL;
count = 0;
for (i = 0; i < len; i++) {
tmp = dtable[src[i]];
if (tmp == 0x80)
continue;
if (src[i] == '=')
pad++;
block[count] = tmp;
count++;
if (count == 4) {
*pos++ = (block[0] << 2) | (block[1] >> 4);
*pos++ = (block[1] << 4) | (block[2] >> 2);
*pos++ = (block[2] << 6) | block[3];
count = 0;
if (pad) {
if (pad == 1)
pos--;
else if (pad == 2)
pos -= 2;
else {
/* Invalid padding */
free(out);
return NULL;
}
break;
}
}
}
*out_len = pos - out;
return out;
}

37
base64.h Normal file
View File

@ -0,0 +1,37 @@
#ifndef _OSAUTH2LIB_BASE64_
#define _OSAUTH2LIB_BASE64_
#include <stdint.h>
#include <stddef.h>
/**
* base64_encode - Base64 encode
* @src: Data to be encoded
* @len: Length of the data to be encoded
* @out_len: Pointer to output length variable, or %NULL if not used
* Returns: Allocated buffer of out_len bytes of encoded data,
* or %NULL on failure
*
* Caller is responsible for freeing the returned buffer. Returned buffer is
* nul terminated to make it easier to use as a C string. The nul terminator is
* not included in out_len.
*/
char *base64_encode(const uint8_t* src, size_t len, size_t* out_len);
/**
* base64_decode - Base64 decode
* @src: Data to be decoded
* @len: Length of the data to be decoded
* @out_len: Pointer to output length variable
* Returns: Allocated buffer of out_len bytes of decoded data,
* or %NULL on failure
*
* Caller is responsible for freeing the returned buffer.
*/
uint8_t *base64_decode(const char* src, size_t len, size_t* out_len);
char *base64url_encode(const uint8_t* src, size_t len, size_t* out_len);
uint8_t *base64url_decode(const char* src, size_t len, size_t* out_len);
#endif //_OSAUTH2LIB_BASE64_

176
log.c Normal file
View File

@ -0,0 +1,176 @@
#include "log.h"
#include <assert.h>
#include <stdarg.h>
#include <time.h>
#include <openssl/err.h>
#include "util/sorted_str_set.h"
static const char* LOG_LEVEL_STRING_TABLE[] = {
"ERROR",
"WARN",
"INFO",
"DEBUG",
"TRACE"
};
// Maybe add __attribute__((constructor)) in the future if available to skip _config_initialized
static bool _user_default_config_initialized = false;
static struct log_config_s _user_default_config;
static sorted_str_set* _module_configs = NULL;
static inline struct log_config_s _default_config() {
if (_user_default_config_initialized) {
return _user_default_config;
}
return (struct log_config_s) {
.loc_error = LOG_LOC_STDS(stderr),
.loc_warn = LOG_LOC_STDS(stderr),
.loc_info = LOG_LOC_STDS(stdout),
.loc_debug = LOG_LOC_NONE(),
.loc_trace = LOG_LOC_NONE(),
.timestamp = true,
};
}
void log_set_default_config(struct log_config_s config) {
_user_default_config = config;
_user_default_config_initialized = true;
}
struct log_config_s* log_get_config(const char* module_name) {
assert(module_name != NULL);
if (!_module_configs) {
if (!(_module_configs = sorted_str_set_new(sizeof(struct log_config_s)))) {
return NULL;
}
}
struct log_config_s default_conf = _default_config();
return sorted_str_set_insert(_module_configs, module_name, &default_conf);
}
void log_log(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...) {
assert(module != NULL);
struct log_config_s* config = log_get_config(module);
assert(config);
struct log_loc* dest_loc;
switch (log_level) {
case LOG_ERR: dest_loc = &config->loc_error; break;
case LOG_WARN: dest_loc = &config->loc_warn; break;
case LOG_INFO: dest_loc = &config->loc_info; break;
case LOG_DEBUG: dest_loc = &config->loc_debug; break;
case LOG_TRACE: dest_loc = &config->loc_trace; break;
default: return;
}
if (dest_loc->type == LOG_LOC_TYPE_DISABLED) {
return;
} else if (dest_loc->type == LOG_LOC_TYPE_FS) {
if (!dest_loc->location) {
if (!(dest_loc->location = fopen(dest_loc->file_name, "a"))) {
return;
}
}
} else if (dest_loc->type != LOG_LOC_TYPE_STDS) {
return;
}
FILE* dest = dest_loc->location;
va_list varargs;
va_start(varargs, fmt);
if (config->timestamp) {
struct timespec utc_time;
struct tm local_time;
char time_str_buf[10+1+8+1];
timespec_get(&utc_time, TIME_UTC);
localtime_r(&(utc_time.tv_sec), &local_time);
if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) {
fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec);
}
}
fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], module);
vfprintf(dest, fmt, varargs);
fprintf(dest, "\n");
if (dest_loc->type == LOG_LOC_TYPE_FS && !dest_loc->keep_file_open) {
fclose(dest_loc->location);
dest_loc->location = NULL;
}
}
void log_log_ssl_err(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...) {
assert(module != NULL);
struct log_config_s* config = log_get_config(module);
assert(config);
struct log_loc* dest_loc;
switch (log_level) {
case LOG_ERR: dest_loc = &config->loc_error; break;
case LOG_WARN: dest_loc = &config->loc_warn; break;
case LOG_INFO: dest_loc = &config->loc_info; break;
case LOG_DEBUG: dest_loc = &config->loc_debug; break;
case LOG_TRACE: dest_loc = &config->loc_trace; break;
default: return;
}
if (dest_loc->type == LOG_LOC_TYPE_DISABLED) {
return;
} else if (dest_loc->type == LOG_LOC_TYPE_FS) {
if (!dest_loc->location) {
if (!(dest_loc->location = fopen(dest_loc->file_name, "a"))) {
return;
}
}
} else if (dest_loc->type != LOG_LOC_TYPE_STDS) {
return;
}
FILE* dest = dest_loc->location;
va_list varargs;
va_start(varargs, fmt);
if (config->timestamp) {
struct timespec utc_time;
struct tm local_time;
char time_str_buf[10+1+8+1];
timespec_get(&utc_time, TIME_UTC);
localtime_r(&(utc_time.tv_sec), &local_time);
if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) {
fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec);
}
}
fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], module);
vfprintf(dest, fmt, varargs);
fprintf(dest, "\n");
ERR_print_errors_fp(dest);
if (dest_loc->type == LOG_LOC_TYPE_FS && !dest_loc->keep_file_open) {
fclose(dest_loc->location);
dest_loc->location = NULL;
}
}

97
log.c.old Normal file
View File

@ -0,0 +1,97 @@
#include "log.h"
#include <assert.h>
#include <stdarg.h>
#include <time.h>
#include <openssl/err.h>
static const char* LOG_LEVEL_STRING_TABLE[] = {
"ERROR",
"WARN",
"INFO",
"DEBUG",
"TRACE"
};
// Maybe add __attribute__((constructor)) in the future if available to skip _config_initialized
static bool _user_default_config_initialized = false;
static struct log_config_s _user_default_config;
static inline struct log_config_s _default_config() {
if (_user_default_config_initialized) {
return _user_default_config;
}
return (struct log_config_s) {
.loc_error = stderr,
.loc_warn = stderr,
.loc_info = stdout,
.loc_debug = NULL,
.loc_trace = NULL,
.timestamp = true,
};
}
void log_set_default_config(struct log_config_s config) {
_user_default_config = config;
_user_default_config_initialized = true;
}
struct logger_s log_new(const char* module_name) {
return (struct logger_s) {
.module_name = module_name,
.config = _default_config(),
};
}
void log_log(const struct logger_s* logger, enum LOG_LEVEL log_level, const char* fmt, ...) {
assert(logger != NULL);
FILE* dest = NULL;
switch (log_level) {
case LOG_ERR: dest = logger->config.loc_error; break;
case LOG_WARN: dest = logger->config.loc_warn; break;
case LOG_INFO: dest = logger->config.loc_info; break;
case LOG_DEBUG: dest = logger->config.loc_debug; break;
case LOG_TRACE: dest = logger->config.loc_trace; break;
default: break;
}
if (dest == NULL) {
return;
}
va_list varargs;
va_start(varargs, fmt);
if (logger->config.timestamp) {
struct timespec utc_time;
struct tm local_time;
char time_str_buf[10+1+8+1];
timespec_get(&utc_time, TIME_UTC);
localtime_r(&(utc_time.tv_sec), &local_time);
if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) {
fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec);
}
}
fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], logger->module_name);
vfprintf(dest, fmt, varargs);
fprintf(dest, "\n");
}
void log_log_ssl_err(const struct logger_s* logger, enum LOG_LEVEL log_level, unsigned long error, const char* message_fmt, ...) {
char error_buf[256];
ERR_error_string_n(ERR_get_error(), error_buf, sizeof(error_buf));
va_list varargs;
va_start(varargs, message_fmt);
log_log(logger, log_level, message_fmt, varargs);
log_log(logger, log_level, "%s", error_buf);
}

54
log.h Normal file
View File

@ -0,0 +1,54 @@
#ifndef _OAUTH2LIB_LOG_
#define _OAUTH2LIB_LOG_
#include <stdbool.h>
#include <stdio.h>
enum LOG_LEVEL {
LOG_ERR,
LOG_WARN,
LOG_INFO,
LOG_DEBUG,
LOG_TRACE,
};
enum LOG_LOC_TYPE {
LOG_LOC_TYPE_DISABLED,
LOG_LOC_TYPE_FS,
LOG_LOC_TYPE_STDS
};
struct log_loc {
enum LOG_LOC_TYPE type;
FILE* location;
const char* file_name;
bool keep_file_open;
};
#define LOG_LOC_FILE(file_location, keep_open) ((struct log_loc) { .type=LOG_LOC_TYPE_FS, .location=NULL, .file_name=file_location, .keep_file_open=keep_open })
#define LOG_LOC_STDS(stream) ((struct log_loc) { .type=LOG_LOC_TYPE_STDS, .location=stream, .file_name=NULL })
#define LOG_LOC_NONE() ((struct log_loc) { .type=LOG_LOC_TYPE_DISABLED })
struct log_config_s {
struct log_loc loc_error;
struct log_loc loc_warn;
struct log_loc loc_info;
struct log_loc loc_debug;
struct log_loc loc_trace;
bool timestamp;
};
struct logger_s {
const char* module_name;
struct log_config_s config;
};
void log_set_default_config(struct log_config_s config);
struct log_config_s* log_get_config(const char* module_name);
void log_log(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...);
void log_log_ssl_err(const char* module, enum LOG_LEVEL log_level, const char* message_fmt, ...);
#endif //_OAUTH2LIB_LOG_

38
log.h.old Normal file
View File

@ -0,0 +1,38 @@
#ifndef _OAUTH2LIB_LOG_
#define _OAUTH2LIB_LOG_
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
enum LOG_LEVEL {
LOG_ERR,
LOG_WARN,
LOG_INFO,
LOG_DEBUG,
LOG_TRACE,
};
struct log_config_s {
FILE* loc_error;
FILE* loc_warn;
FILE* loc_info;
FILE* loc_debug;
FILE* loc_trace;
bool timestamp;
};
struct logger_s {
const char* module_name;
struct log_config_s config;
};
void log_set_default_config(struct log_config_s config);
struct logger_s log_new(const char* module_name);
void log_log(const struct logger_s* logger, enum LOG_LEVEL log_level, const char* fmt, ...);
void log_log_ssl_err(const struct logger_s* logger, enum LOG_LEVEL log_level, unsigned long error, const char* message_fmt, ...);
#endif //_OAUTH2LIB_LOG_

23
main.c Normal file
View File

@ -0,0 +1,23 @@
#include <stdlib.h>
#include "networking.h"
#include "pkce.h"
#include "log.h"
#include "server.h"
#include "ssl.h"
#include <openssl/ssl.h>
int main(int argc, const char* argv[]) {
/*SSL_CTX* ctx = SSL_CTX_new(TLS_server_method());
ssl_use_ppcerts(ctx);*/
// tcp_get(argv[1]);
//do_pkce();
server_start();
return 0;
}

131
main.c.old Normal file
View File

@ -0,0 +1,131 @@
#include <errno.h>
#include <netdb.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/in.h>
//#include "dns.h"
#define WEB_ADDR "www.onedrive.com"
struct socket_data_chunk {
char data[4096];
struct socket_data_chunk *next;
};
void free_socketdata(struct socket_data_chunk* data_chunk) {
for (struct socket_data_chunk *iter = data_chunk; iter != NULL; ) {
struct socket_data_chunk *data = iter;
iter = iter->next;
free(data);
}
}
int strncmp_end(const char *str, size_t max_strlen, const char *suffix, size_t max_suffix_len)
{
if (!str || !suffix)
return 0;
size_t lenstr = strnlen(str, max_strlen);
size_t lensuffix = strnlen(suffix, max_suffix_len);
if (lensuffix > lenstr)
return 0;
return strncmp(str + lenstr - lensuffix, suffix, lensuffix);
}
int main() {
//oauth2_dns_resolve("onedrive.com");
struct addrinfo resolve_hints = {
.ai_family = AF_INET, //AF_UNSPEC, // IPv4 or IPv6
.ai_socktype = SOCK_STREAM,
.ai_flags = 0,
.ai_protocol = 0
};
struct addrinfo *resolve_result;
int get_addr_res = getaddrinfo(WEB_ADDR, NULL, &resolve_hints, &resolve_result);
if (get_addr_res != 0) {
printf("getaddrinfo failed: %d\n", get_addr_res);
exit(-1);
}
char ip_addr[50];
printf("IP address: %s\n", inet_ntop(AF_INET, &((struct sockaddr_in*) resolve_result->ai_addr)->sin_addr, ip_addr, sizeof(ip_addr)));
printf("Port: %d\n", ntohs(((struct sockaddr_in*)resolve_result->ai_addr)->sin_port));
((struct sockaddr_in*)resolve_result->ai_addr)->sin_port = htons(80);
int sock_desc = socket(AF_INET, SOCK_STREAM, 0);
if (sock_desc == -1) {
printf("Socket creation failed. Errno: %d\n", errno);
exit(-1);
}
printf("created socket\n");
struct sockaddr_in server_addr = *(struct sockaddr_in *)resolve_result->ai_addr;
// struct sockaddr_in server_addr = {
// .sin_family = AF_INET,
// .sin_addr = inet_addr("127.0.0.1"),
// .sin_port = htons(8080)
// };
if (connect(sock_desc, (struct sockaddr*)&server_addr, sizeof(server_addr)) != 0) {
printf("Socket connection failed. Errno: %d\n", errno);
exit(-1);
close(sock_desc);
}
printf("Connected to address\n");
const char *msg =
"GET / HTTP/1.1\r\n"
"Host: " WEB_ADDR "\r\n"
"\r\n";
write(sock_desc, msg, strlen(msg));
printf("sent: \n%s\n", msg);
struct socket_data_chunk *data = NULL;
struct socket_data_chunk **it = &data;
do {
struct socket_data_chunk *nit = calloc(1, sizeof(struct socket_data_chunk));
if (!nit) {
printf("failed to allocate data for receive\n");
exit(-1);
}
if (*it) {
(*it)->next = nit;
} else {
*it = nit;
}
it = &nit;
ssize_t bytes_read = recv(sock_desc, &nit->data, sizeof(nit->data), 0);
if (bytes_read == -1) {
printf("Failed to recieve data. Errno: %d", errno);
exit(-1);
}
printf("received (%ld) bytes:\n%s\n", bytes_read, nit->data);
//} while (!strstr((*it)->data, "\r\n\r\n"));
} while(strncmp_end(data->data, sizeof((*it)->data), "\r\n\r\n", 5));
printf("Finishing\n");
close(sock_desc);
freeaddrinfo(resolve_result);
free_socketdata(data);
return 0;
}

221
networking.c Normal file
View File

@ -0,0 +1,221 @@
#include "networking.h"
#include <openssl/bio.h>
#include <stdbool.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <netdb.h>
#include <sys/socket.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "ssl.h"
typedef int socket_fd;
static SSL_CTX *ssl_ctx = NULL;
char *strstr_fixn(const char *haystack, const char *needle, size_t len) {
size_t needle_len;
if (0 == (needle_len = strnlen(needle, len)))
return (char *)haystack;
for (size_t i = 0; i <= len - needle_len; i++)
{
if ((haystack[0] == needle[0]) &&
(0 == strncmp(haystack, needle, needle_len)))
return (char *)haystack;
haystack++;
}
return NULL;
}
static socket_fd connect_to_http(const char* hostname) {
struct addrinfo resolve_hints = {
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_STREAM,
};
struct addrinfo* resolve_result;
int result = getaddrinfo(hostname, "https", &resolve_hints, &resolve_result);
if (result != 0) {
fprintf(stderr, "Failed to get IP address of \"%s\" for an http connection. Error: %s\n", hostname, gai_strerror(result));
return -1;
}
socket_fd sock_desc = -1;
for (struct addrinfo *it = resolve_result; it != NULL; it = it->ai_next) {
sock_desc = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
if (sock_desc == -1) {
continue;
}
if (connect(sock_desc, it->ai_addr, it->ai_addrlen) != -1) {
uint16_t port = ntohs(
(it->ai_family == AF_INET)
? ((struct sockaddr_in*) it->ai_addr)->sin_port
: ((struct sockaddr_in6*) it->ai_addr)->sin6_port);
printf("Connected to \"%s\" on port %d\n", hostname, port);
break;
}
close(sock_desc);
}
freeaddrinfo(resolve_result);
return sock_desc;
}
void tcp_get(const char* addr) {
initialize_ssl();
printf("initialized ssl\n");
socket_fd sock_fd = connect_to_http(addr);
printf("Got socket: %d\n", sock_fd);
SSL* ssl_conn = SSL_new(ssl_ctx);
if (!ssl_conn) {
fprintf(stderr, "Failed to create an SSL connection provider.\n");
ERR_print_errors_fp(stderr);
close(sock_fd);
return;
}
SSL_set_fd(ssl_conn, sock_fd);
if (!SSL_set_tlsext_host_name(ssl_conn, addr)) {
fprintf(stderr, "Failed to set SNI hostname.\n");
ERR_print_errors_fp(stderr);
SSL_free(ssl_conn);
return;
}
if (!SSL_set1_host(ssl_conn, addr)) {
fprintf(stderr, "Failed to set certificate hostname.\n");
ERR_print_errors_fp(stderr);
SSL_free(ssl_conn);
return;
}
if (SSL_connect(ssl_conn) < 1) {
fprintf(stderr, "Failed to connect to server.\n");
if (SSL_get_verify_result(ssl_conn) != X509_V_OK) {
fprintf(stderr, "Verification error: %s\n", X509_verify_cert_error_string(SSL_get_verify_result(ssl_conn)));
} else {
ERR_print_errors_fp(stderr);
}
SSL_free(ssl_conn);
return;
}
printf("Connected to peer.\n");
const char* HTTP_HEAD = "GET / HTTP/1.1\r\n"
"Connection: close\r\n"
"Host: ";
const char* HTTP_HEAD_END = "\r\n\r\n";
const char* msg[] = { HTTP_HEAD, addr, HTTP_HEAD_END };
for (size_t i = 0; i < sizeof(msg)/sizeof(msg[0]); ++i) {
if (!SSL_write(ssl_conn, msg[i], strlen(msg[i]))) {
fprintf(stderr, "Failed to send message: \"%s\"\n", msg[i]);
ERR_print_errors_fp(stderr);
SSL_free(ssl_conn);
return;
}
}
printf("Sent HTTP request to peer.\n");
size_t read_bytes;
char read_buf[160];
enum { HTTP_CONTENT_LENGTH, HTTP_CONTENT_CHUNKED } http_content_transmit_type;
size_t total_read_bytes;
size_t expected_bytes = -1;
bool body = false;
char last_3_of_prev[3] = {};
printf("\n");
while (SSL_read_ex(ssl_conn, read_buf, sizeof(read_buf), &read_bytes)) {
total_read_bytes += read_bytes;
if ((http_content_transmit_type == HTTP_CONTENT_CHUNKED && read_buf[0] == '0')
|| http_content_transmit_type == HTTP_CONTENT_LENGTH && total_read_bytes == expected_bytes) {
SSL_set_options(ssl_conn, SSL_OP_IGNORE_UNEXPECTED_EOF);
}
if (!body) {
const char* content_length_str = "Content-Length: ";
const char* transfer_encoding_str = "Transfer-Encoding: chunked";
char* content_length_start = strstr_fixn(read_buf, content_length_str, read_bytes);
if (content_length_start != NULL) {
http_content_transmit_type = HTTP_CONTENT_LENGTH;
content_length_start += strlen(content_length_str);
expected_bytes = strtol(content_length_start, NULL, 10);
}
if (strstr_fixn(read_buf, transfer_encoding_str, read_bytes)) {
http_content_transmit_type = HTTP_CONTENT_CHUNKED;
}
char concat[] = { last_3_of_prev[0], last_3_of_prev[1], last_3_of_prev[2], read_buf[0], read_buf[1], read_buf[2], 0 };
char *concat_start;
if ((concat_start = strstr(concat, HTTP_HEAD_END))) {
body = true;
total_read_bytes = read_bytes - (concat_start - concat + 1);
} else {
char* body_start = strstr_fixn(read_buf, HTTP_HEAD_END, read_bytes);
if (body_start) {
body = true;
total_read_bytes = read_bytes - (body_start + strlen(HTTP_HEAD_END) - read_buf);
}
}
}
fwrite(read_buf, 1, read_bytes, stdout);
}
printf("\n");
if (SSL_get_error(ssl_conn, 0) != SSL_ERROR_ZERO_RETURN) {
fprintf(stderr, "Failed to read response.\n");
ERR_print_errors_fp(stderr);
}
printf("Recieved %ld bytes", total_read_bytes);
printf(expected_bytes != -1 ? ", expected: %ld bytes\n" : "\n", expected_bytes);
printf("Ignore unexpected eof? %s\n", SSL_get_options(ssl_conn) & SSL_OP_IGNORE_UNEXPECTED_EOF ? "true" : "false");
if (SSL_shutdown(ssl_conn) < 1) {
fprintf(stderr, "Failed to shut down connection gracefully.\n");
ERR_print_errors_fp(stderr);
}
SSL_free(ssl_conn);
printf("end\n");
}

11
networking.h Normal file
View File

@ -0,0 +1,11 @@
#ifndef _OAUTH2LIB_NETWORKING_
#define _OAUTH2LIB_NETWORKING_
struct content {
};
void tcp_get(const char* addr);
#endif //_OAUTH2LIB_NETWORKING_

145
pkce.c Normal file
View File

@ -0,0 +1,145 @@
#include "pkce.h"
#include <openssl/sha.h>
#include <stddef.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#include "base64.h"
#include "log.h"
#include "ssl.h"
#define MODULE_NAME "PKCE"
#define CODE_VERIFIER_LENGTH 32 // the length of the binary random string before base64url encoding
struct code_verifier_s {
size_t length;
char* key;
};
static inline void free_code_verifier(struct code_verifier_s* ptr) { free(ptr->key); }
static struct code_verifier_s generate_code_verifier() {
uint8_t code_buffer[CODE_VERIFIER_LENGTH];
int gen_result = RAND_priv_bytes(code_buffer, sizeof(code_buffer));
if (gen_result != 1) {
log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to generate code_verifier.\n");
return (struct code_verifier_s) { 0, NULL };
}
size_t encoded_size = 0;
char* encoded = base64url_encode(code_buffer, sizeof(code_buffer), &encoded_size);
if (!encoded) {
log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to encode code_verifier.\n");
return (struct code_verifier_s) { 0, NULL };
}
// trim trailing =\n
encoded_size -= 2;
encoded[encoded_size] = '\0';
return (struct code_verifier_s) { encoded_size, encoded };
}
static uint8_t* sha256_encode(const char* str, size_t str_len) {
EVP_MD_CTX* hash_ctx = EVP_MD_CTX_new();
if (!hash_ctx) {
log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to create Crypto context.\n");
return NULL;
}
EVP_MD* sha256_ctx = ssl_get_sha256();
if (!sha256_ctx) {
EVP_MD_CTX_free(hash_ctx);
return NULL;
}
if (!EVP_DigestInit_ex(hash_ctx, sha256_ctx, NULL)) {
log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to initialize sha256 digest.\n");
EVP_MD_CTX_free(hash_ctx);
return NULL;
}
if (!EVP_DigestUpdate(hash_ctx, str, str_len)) {
log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to set message to digest.\n");
EVP_MD_CTX_free(hash_ctx);
return NULL;
}
uint8_t *hash_buf = malloc(EVP_MD_get_size(sha256_ctx));
if (!hash_buf) {
log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to allocate memory for digest output.\n");
EVP_MD_CTX_free(hash_ctx);
return NULL;
}
unsigned int len = 0;
if (!EVP_DigestFinal_ex(hash_ctx, hash_buf, &len)) {
log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to compute hash.\n");
EVP_MD_CTX_free(hash_ctx);
return NULL;
}
EVP_MD_CTX_free(hash_ctx);
return hash_buf;
}
char* _construct_get(char* server_url, char* client_id, char* redirect_uri, char* scope, char* state) {
static char msg_template[] = "GET %s?response_type=code&client_id=%s&state=%s&redirect_uri=%s HTTP/1.1\r\nHost: localhost\r\n\r\n";
size_t len = strlen(msg_template) - 4*2 + strlen(server_url)+strlen(client_id)+strlen(redirect_uri)+strlen(scope)+strlen(state) + 1;
char* buf = malloc(len);
if (!buf) { return NULL; }
if (snprintf(buf, len, msg_template, server_url, client_id, state, redirect_uri) < 0) {
free(buf);
return NULL;
}
return buf;
}
void do_pkce() {
log_log("PKCE", LOG_INFO, "generating code verifier");
struct code_verifier_s code_verifier = generate_code_verifier();
log_log("PKCE", LOG_DEBUG, "generated verifer: %s\n", code_verifier.key);
log_log("PKCE", LOG_INFO, "generating code challenge");
uint8_t* code_challenge_sha256 = sha256_encode(code_verifier.key, code_verifier.length);
if (!code_challenge_sha256) {
log_log("PKCE", LOG_ERR, "Failed to encode code challenge");
goto pkce_cleanup;
}
size_t code_challenge_length = 0;
char* code_challenge = base64url_encode(code_challenge_sha256, EVP_MD_get_size(ssl_get_sha256()), &code_challenge_length);
code_challenge_length -= 2;
code_challenge[code_challenge_length] = '\0';
log_log("PKCE", LOG_DEBUG, "generated code challenge: %s\n", code_challenge);
pkce_cleanup:
free_code_verifier(&code_verifier);
free(code_challenge);
}

6
pkce.h Normal file
View File

@ -0,0 +1,6 @@
#ifndef _OAUTH2LIB_PKCE_
#define _OAUTH2LIB_PKCE_
void do_pkce();
#endif //_OAUTH2LIB_PKCE_

141
server.c Normal file
View File

@ -0,0 +1,141 @@
#include "server.h"
#include <stdatomic.h>
#include <unistd.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <openssl/err.h>
#include "log.h"
#include "ssl.h"
#include "util/thread_queue.h"
#define MODULE_NAME "login_server"
#define SERVER_PORT 443
static const char message[] = "<html><body><h1>Hello World!</h1></body></html>";
static SSL_CTX* _ssl_server_ctx = NULL;
static struct logger_s _logger;
static pthread_t *_server_thread = NULL;
static struct threadqueue _server_thread_queue;
_Atomic enum server_state _server_state = SERVER_NOT_STARTED;
static enum server_start_return _server_thread_func(void* data) {
atomic_store(&_server_state, SERVER_STARTING);
struct sockaddr_in sockaddr = {
.sin_family = AF_INET,
.sin_port = htons(SERVER_PORT),
.sin_addr = {
.s_addr = htonl(INADDR_ANY)
}
};
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
log_log(MODULE_NAME, LOG_ERR, "Failed to create socket: %s", strerror(errno));
atomic_store(&_server_state, SERVER_FAILED);
return SERVER_FAIL;
}
if (bind(sock, (struct sockaddr*)&sockaddr, sizeof(sockaddr)) < 0) {
log_log(MODULE_NAME, LOG_ERR, "Failed to bind socket: %s", strerror(errno));
atomic_store(&_server_state, SERVER_FAILED);
close(sock);
return SERVER_FAIL;
}
if (listen(sock, 1) != 0) {
log_log(MODULE_NAME, LOG_ERR, "Failed to start listening on socket: %s", strerror(errno));
atomic_store(&_server_state, SERVER_FAILED);
close(sock);
return SERVER_FAIL;
}
log_log(MODULE_NAME, LOG_INFO, "Waiting for incoming connections...");
atomic_store(&_server_state, SERVER_RUNNING);
while(1) {
struct sockaddr addr;
socklen_t addr_len = sizeof(addr);
int client_fd = accept(sock, &addr, &addr_len);
if (client_fd < 0) {
log_log(MODULE_NAME, LOG_ERR, "Failed to accept an incoming connection: %s", strerror(errno));
continue;
}
SSL* ssl_conn = SSL_new(_ssl_server_ctx);
SSL_set_fd(ssl_conn, client_fd);
if (SSL_accept(ssl_conn) <= 0) {
log_log_ssl_err(MODULE_NAME, LOG_ERR, "SSL Handshake failed.");
} else {
size_t bufsize = 1024;
char* buf = malloc(bufsize);
if (SSL_read(ssl_conn, buf, bufsize) > 0) {
log_log(MODULE_NAME, LOG_INFO, "Received message: %s", buf);
}
SSL_write(ssl_conn, message, sizeof(message));
}
SSL_shutdown(ssl_conn);
SSL_free(ssl_conn);
close(client_fd);
}
_server_thread_shutdown:
close(sock);
}
enum server_start_return server_start() {
initialize_ssl();
if (!_ssl_server_ctx) {
_ssl_server_ctx = SSL_CTX_new(TLS_server_method());
if (!_ssl_server_ctx) {
log_log(MODULE_NAME, LOG_ERR, "Failed to create SSL CTX");
return SERVER_FAIL;
}
if (!ssl_use_ppcerts(_ssl_server_ctx)) {
log_log(MODULE_NAME, LOG_ERR, "Failed to set certificates.");
SSL_CTX_free(_ssl_server_ctx);
_ssl_server_ctx = NULL;
return SERVER_FAIL;
}
}
if (!_server_thread) {
_server_thread = malloc(sizeof(pthread_t));
if (!_server_thread) {
log_log(MODULE_NAME, LOG_ERR, "Failed to allocate resources for server_thread");
return SERVER_FAIL;
}
thread_queue_init(&_server_thread_queue);
pthread_create(_server_thread, NULL, (void* (*)(void*))_server_thread_func, NULL);
pthread_join(*_server_thread, NULL);
free(_server_thread);
_server_thread = NULL;
}
return SERVER_OK;
}

20
server.h Normal file
View File

@ -0,0 +1,20 @@
#ifndef _OAUTH2LIB_SERVER_
#define _OAUTH2LIB_SERVER_
enum server_start_return {
SERVER_FAIL = 0,
SERVER_OK = 1,
SERVER_STARTED
};
enum server_state {
SERVER_FAILED,
SERVER_NOT_STARTED,
SERVER_STARTING,
SERVER_RUNNING,
SERVER_FINISHING,
};
enum server_start_return server_start();
#endif //_OAUTH2LIB_SERVER_

183
ssl.c Normal file
View File

@ -0,0 +1,183 @@
#include "ssl.h"
#include "log.h"
#include <errno.h>
#include <unistd.h>
#include <openssl/core_names.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#define LOG_NAME "SSL"
static bool _initialized = false;
static SSL_CTX* _ssl_ctx = NULL;
static EVP_MD* _sha256_ctx = NULL;
static const char _pkey_file_name[] = "key.pem";
static const char _cert_file_name[] = "cert.pem";
static const char _passphrase[] = "This is absolutely insecure.";
static int _pem_passphrase_cb(char* buf, int size, int rwflag, void* userdata) {
int len = strlen(_passphrase);
memcpy(buf, _passphrase, len < size ? len : size);
return len;
}
bool _initialize_ssl() {
SSL_library_init();
SSL_load_error_strings();
_ssl_ctx = SSL_CTX_new(TLS_client_method());
if (_ssl_ctx == NULL) {
fprintf(stderr, "Failed to initialize Openssl.\n");
return false;
}
SSL_CTX_set_verify(_ssl_ctx, SSL_VERIFY_PEER, NULL);
if (!SSL_CTX_set_default_verify_paths(_ssl_ctx)) {
fprintf(stderr, "Failed to set default paths for openssl.\n");
return false;
}
if (!SSL_CTX_set_min_proto_version(_ssl_ctx, TLS1_3_VERSION)) {
fprintf(stderr, "Failed to set minimum TLS version.\n");
return false;
}
return true;
}
bool initialize_ssl() {
if (!_initialized) {
_initialized = _initialize_ssl();
}
return _initialized;
}
void ssl_close_ssl() {
SSL_CTX_free(_ssl_ctx);
EVP_MD_free(_sha256_ctx);
}
SSL_CTX* ssl_get_ctx() { return _ssl_ctx; }
EVP_MD* ssl_get_sha256() { return _sha256_ctx; }
bool ssl_generate_cert() {
EVP_PKEY* pkey = NULL;
X509* cert = NULL;
X509_NAME* cert_issuer;
bool success = false;
pkey = EVP_EC_gen("P-256");
if (pkey == NULL) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to generate private key for certificate.");
goto generate_cert_cleanup;
}
cert = X509_new();
if (cert == NULL) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to generate certificate.");
goto generate_cert_cleanup;
}
ASN1_INTEGER_set(X509_get_serialNumber(cert), 1);
X509_gmtime_adj(X509_get_notBefore(cert), 0);
X509_gmtime_adj(X509_get_notAfter(cert), 31536000L);
cert_issuer = X509_get_subject_name(cert);
if (cert_issuer == NULL ||
!( X509_NAME_add_entry_by_txt(cert_issuer, "C", MBSTRING_ASC, (unsigned char*)"AT", -1, -1, 0)
&& X509_NAME_add_entry_by_txt(cert_issuer, "O", MBSTRING_ASC, (unsigned char*)"me", -1, -1, 0)
&& X509_NAME_add_entry_by_txt(cert_issuer, "CN", MBSTRING_ASC, (unsigned char*)"localhost", -1, -1, 0))) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to create issuer name.");
goto generate_cert_cleanup;
}
if (!X509_set_issuer_name(cert, cert_issuer)) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to set issuer name on certificate.");
goto generate_cert_cleanup;
}
if (!X509_set_pubkey(cert, pkey)) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to set private key on certificate");
goto generate_cert_cleanup;
}
if (!X509_sign(cert, pkey, EVP_sha256())) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to sign certificate.");
goto generate_cert_cleanup;
}
FILE* key_file = NULL;
key_file = fopen(_pkey_file_name, "wb");
if (!PEM_write_PrivateKey(key_file, pkey, EVP_des_ede3_cbc(), (unsigned char*)_passphrase, sizeof(_passphrase), NULL, NULL)) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to write private certificate to disk.");
fclose(key_file);
goto generate_cert_cleanup;
}
fclose(key_file);
key_file = fopen(_cert_file_name, "wb");
if (!PEM_write_X509(key_file, cert)) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to write certificate to disk.");
fclose(key_file);
goto generate_cert_cleanup;
}
fclose(key_file);
success = true;
generate_cert_cleanup:
EVP_PKEY_free(pkey);
X509_free(cert);
X509_NAME_free(cert_issuer);
return success;
}
bool ssl_use_ppcerts(SSL_CTX* ctx) {
if (access(_cert_file_name, R_OK) != 0 && access(_pkey_file_name, R_OK) != 0) {
log_log(LOG_NAME, LOG_INFO, "Could not access certificates - Generating new ones");
bool res = ssl_generate_cert();
if (!res) {
log_log(LOG_NAME, LOG_ERR, "Failed to generate certificates");
return false;
}
}
if (SSL_CTX_use_certificate_chain_file(ctx, _cert_file_name) != 1) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to load SSL certificate.");
return false;
}
SSL_CTX_set_default_passwd_cb(ctx, _pem_passphrase_cb);
if (SSL_CTX_use_PrivateKey_file(ctx, _pkey_file_name, SSL_FILETYPE_PEM) != 1) {
log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to load SSL Private Key.");
return false;
}
SSL_CTX_set_default_passwd_cb(ctx, NULL);
return true;
}

17
ssl.h Normal file
View File

@ -0,0 +1,17 @@
#ifndef _OAUTH2LIB_SSL_
#define _OAUTH2LIB_SSL_
#include <stdbool.h>
#include <openssl/ssl.h>
bool initialize_ssl();
SSL_CTX* ssl_get_ctx();
EVP_MD* ssl_get_sha256();
bool ssl_generate_cert();
bool ssl_use_ppcerts(SSL_CTX* ctx);
#endif //_OAUTH2LIB_SSL_

BIN
util/a.out Executable file

Binary file not shown.

174
util/sorted_str_set.c Normal file
View File

@ -0,0 +1,174 @@
#include "sorted_str_set.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#define SORTED_STR_SET_INITIAL_CAPACITY 8
#define SORTED_STR_SET_GROWTH_FACTOR 1.5
static ssize_t _get(const sorted_str_set* set, const char* key, int* out_res) {
ssize_t index = 0;
ssize_t low = 0;
ssize_t upp = set->size - 1;
int res = 0;
while (low <= upp) {
index = low + (upp-low)/2;
res = strcmp(key, set->data[index].key);
if (res == 0) {
if (out_res) *out_res = res;
return index;
}
if (res < 0) {
upp = index - 1;
} else {
low = index + 1;
}
}
if (res > 0) {
++index;
}
return -1;
}
sorted_str_set* sorted_str_set_new(size_t element_size) {
sorted_str_set* set = malloc(sizeof(sorted_str_set));
if (!set) { return NULL; }
struct sorted_str_set_pair* data = malloc(SORTED_STR_SET_INITIAL_CAPACITY * sizeof(struct sorted_str_set_pair));
if (!data) {
free(set);
return NULL;
}
*set = (sorted_str_set) {
.element_size = element_size,
.size = 0,
.capacity = SORTED_STR_SET_INITIAL_CAPACITY,
.data = data,
};
return set;
}
void sorted_str_set_free(sorted_str_set* set) {
assert(set); //, "set must not be NULL.");
for (int i = 0; i < set->size; ++i) {
free(set->data[i].key);
free(set->data[i].element);
}
free(set->data);
free(set);
}
void* sorted_str_set_get(const sorted_str_set* set, const char* key) {
assert(set); //, "set must not be null");
assert(key); //, "key must not be null");
if (set->size == 0) {
return NULL;
}
ssize_t i = _get(set, key, NULL);
if (i == -1) {
return NULL;
} else {
return set->data[i].element;
}
}
void* sorted_str_set_insert(sorted_str_set* set, const char* key, const void* element) {
assert(set); //, "set must not be NULL");
assert(key); //, "key must not be NULL");
assert(element); //, "element must not be NULL");
size_t index = 0;
if (set->size != 0) {
ssize_t low = 0;
ssize_t upp = set->size - 1;
int res = 0;
while (low <= upp) {
index = low + (upp-low)/2;
res = strcmp(key, set->data[index].key);
if (res == 0) {
return set->data[index].element;
}
if (res < 0) {
upp = index - 1;
} else {
low = index + 1;
}
}
if (res > 0) {
++index;
}
}
if (set->size == set->capacity) {
size_t new_cap = set->capacity * SORTED_STR_SET_GROWTH_FACTOR;
struct sorted_str_set_pair* tmp = reallocarray(set->data, new_cap, sizeof(struct sorted_str_set_pair));
if (!tmp) {
return NULL;
}
set->data = tmp;
set->capacity = new_cap;
}
struct sorted_str_set_pair pair = {
.key = malloc(strlen(key) + 1),
.element = malloc(set->element_size),
};
if (!pair.key || !pair.element) {
free(pair.key);
free(pair.element);
return NULL;
}
strcpy(pair.key, key);
memcpy(pair.element, element, set->element_size);
if (set->size - index) {
memmove(&set->data[index + 1], &set->data[index], (set->size - index) * sizeof(struct sorted_str_set_pair));
}
set->data[index] = pair;
set->size++;
return set->data[index].element;
}
void sorted_str_set_remove(sorted_str_set* set, const char* key) {
assert(set);//, "set must not be NULL");
assert(key);//, "key must not be NULL");
ssize_t index = _get(set, key, NULL);
if (index == -1) {
return;
}
struct sorted_str_set_pair* loc = &set->data[index];
free(loc->key);
free(loc->element);
set->size--;
memmove(loc, loc+1, (set->size - index) * sizeof(struct sorted_str_set_pair));
}

69
util/sorted_str_set.h Normal file
View File

@ -0,0 +1,69 @@
#ifndef _OAUTH2LIB_UTIL_SORTED_SET_
#define _OAUTH2LIB_UTIL_SORTED_SET_
#include <stddef.h>
struct sorted_str_set_pair {
char* key;
void* element;
};
typedef struct {
size_t element_size;
size_t size;
size_t capacity;
struct sorted_str_set_pair* data;
} sorted_str_set;
/**
* @name sorted_str_set_new
* @description Allocates and returns a new instance of sorted_str_set. Must be freed with @sorted_str_set_free
*
* @param element_size the size of each element (not the key)
* @return the newly allocated set or NULL if allocation fails
*/
sorted_str_set* sorted_str_set_new(size_t element_size);
/**
* @name sorted_str_set_free
* @description Frees all resources of an instance of sorted_str_set. Has to be called before the end of the lifetime of the variable.
* Does not free any memory possibly pointed to by element.
*
* @param set the set to be deallocated. Must not be NULL.
*/
void sorted_str_set_free(sorted_str_set* set);
/**
* @name sorted_str_set_free
* @description Returns a pointer to the element associated with key within set
*
* @param set the set to be accessed. Must not be NULL.
* @param key the key whose value should be retrieved. Must be null-terminated.
* @return a pointer to the element associated with key or NULL if key is not in the set
*/
void* sorted_str_set_get(const sorted_str_set* set, const char* key);
/**
* @name sorted_str_set_insert
* @description Inserts a new element associated with key into set and returns a pointer to the element.
* If key is already present in set, nothing is modified and a pointer to the existing element is returned.
*
* @param set the set to be accessed. Must not be NULL.
* @param key the key for the element. Must not be NULL. Must be null-terminated.
* @param element the element to be inserted. A copy of element is made. Is expected to be the size of set->element_size. Must not be NULL.
* @return a pointer to the newly inserted element or if the key is already existing, the element associated with key.
*/
void* sorted_str_set_insert(sorted_str_set* set, const char* key, const void* element);
/**
* @name sorted_str_set_remove
* @description Removes the key and element associated by key from set. If the key is not present, does nothing.
*
* @param set the set to be accessed. Must not be NULL.
* @aram key the key whose element should be removed. Must not be NULL.
*/
void sorted_str_set_remove(sorted_str_set* set, const char* key);
#endif //_OAUTH2LIB_UTIL_SORTED_SET_

View File

@ -0,0 +1,67 @@
#include "sorted_set.h"
#include <stdio.h>
#include <string.h>
#define REF(dtype, value) ((dtype []){ (value) })
void print_set(const sorted_str_set* set) {
for (int i = 0; i < set->size; ++i) {
printf("\"%s\": %d\n", set->data[i].key, *(int*)set->data[i].element);
}
}
int main() {
sorted_str_set *set = sorted_str_set_new(sizeof(int));
printf("%p\n", sorted_str_set_get(set, "zero"));
printf("-%d\n", strcmp("-one", "one"));
sorted_str_set_insert(set, "three", REF(int, 3));
sorted_str_set_insert(set, "one", REF(int, 1));
sorted_str_set_insert(set, "-one", REF(int, -1));
sorted_str_set_insert(set, "four", REF(int, 4));
sorted_str_set_insert(set, "two", REF(int, 2));
print_set(set);
printf("\nsize: %lu\n", set->size);
printf("%d\n", *(int*)sorted_str_set_get(set, "one"));
printf("%d\n", *(int*)sorted_str_set_get(set, "two"));
printf("%d\n", *(int*)sorted_str_set_get(set, "three"));
printf("%d\n", *(int*)sorted_str_set_get(set, "four"));
printf("%d\n", *(int*)sorted_str_set_get(set, "-one"));
printf("\n%d should be 2\n", *(int*)sorted_str_set_insert(set, "two", REF(int, -2)));
printf("size: %lu (should be the same as before)\n", set->size);
sorted_str_set_insert(set, "ten", REF(int, 10));
sorted_str_set_insert(set, "eleven", REF(int, 11));
sorted_str_set_insert(set, "negative four", REF(int, -4));
sorted_str_set_insert(set, "fifty", REF(int, 50));
printf("%d\n", *(int*)sorted_str_set_get(set, "ten"));
printf("%d\n", *(int*)sorted_str_set_get(set, "eleven"));
printf("%d\n", *(int*)sorted_str_set_get(set, "negative four"));
printf("%d\n", *(int*)sorted_str_set_get(set, "fifty"));
printf("size: %lu\n", set->size);
print_set(set);
sorted_str_set_remove(set, "three");
sorted_str_set_remove(set, "one");
sorted_str_set_remove(set, "three");
sorted_str_set_remove(set, "four");
sorted_str_set_remove(set, "four");
printf("size: %lu\n", set->size);
print_set(set);
sorted_str_set_free(set);
return 0;
}

204
util/thread_queue.c Normal file
View File

@ -0,0 +1,204 @@
#include "thread_queue.h"
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <pthread.h>
#include <sys/time.h>
#define MSGPOOL_SIZE 256
struct msglist {
struct threadmsg msg;
struct msglist *next;
};
static inline struct msglist *get_msglist(struct threadqueue *queue)
{
struct msglist *tmp;
if(queue->msgpool != NULL) {
tmp = queue->msgpool;
queue->msgpool = tmp->next;
queue->msgpool_length--;
} else {
tmp = malloc(sizeof *tmp);
}
return tmp;
}
static inline void release_msglist(struct threadqueue *queue,struct msglist *node)
{
if(queue->msgpool_length > ( queue->length/8 + MSGPOOL_SIZE)) {
free(node);
} else {
node->msg.data = NULL;
node->msg.msgtype = 0;
node->next = queue->msgpool;
queue->msgpool = node;
queue->msgpool_length++;
}
if(queue->msgpool_length > (queue->length/4 + MSGPOOL_SIZE*10)) {
struct msglist *tmp = queue->msgpool;
queue->msgpool = tmp->next;
free(tmp);
queue->msgpool_length--;
}
}
int thread_queue_init(struct threadqueue *queue)
{
int ret = 0;
if (queue == NULL) {
return EINVAL;
}
memset(queue, 0, sizeof(struct threadqueue));
ret = pthread_cond_init(&queue->cond, NULL);
if (ret != 0) {
return ret;
}
ret = pthread_mutex_init(&queue->mutex, NULL);
if (ret != 0) {
pthread_cond_destroy(&queue->cond);
return ret;
}
return 0;
}
int thread_queue_add(struct threadqueue *queue, void *data, long msgtype)
{
struct msglist *newmsg;
pthread_mutex_lock(&queue->mutex);
newmsg = get_msglist(queue);
if (newmsg == NULL) {
pthread_mutex_unlock(&queue->mutex);
return ENOMEM;
}
newmsg->msg.data = data;
newmsg->msg.msgtype = msgtype;
newmsg->next = NULL;
if (queue->last == NULL) {
queue->last = newmsg;
queue->first = newmsg;
} else {
queue->last->next = newmsg;
queue->last = newmsg;
}
if(queue->length == 0)
pthread_cond_broadcast(&queue->cond);
queue->length++;
pthread_mutex_unlock(&queue->mutex);
return 0;
}
int thread_queue_get(struct threadqueue *queue, const struct timespec *timeout, struct threadmsg *msg)
{
struct msglist *firstrec;
int ret = 0;
struct timespec abstimeout;
if (queue == NULL || msg == NULL) {
return EINVAL;
}
if (timeout) {
struct timeval now;
gettimeofday(&now, NULL);
abstimeout.tv_sec = now.tv_sec + timeout->tv_sec;
abstimeout.tv_nsec = (now.tv_usec * 1000) + timeout->tv_nsec;
if (abstimeout.tv_nsec >= 1000000000) {
abstimeout.tv_sec++;
abstimeout.tv_nsec -= 1000000000;
}
}
pthread_mutex_lock(&queue->mutex);
/* Will wait until awakened by a signal or broadcast */
while (queue->first == NULL && ret != ETIMEDOUT) { //Need to loop to handle spurious wakeups
if (timeout) {
ret = pthread_cond_timedwait(&queue->cond, &queue->mutex, &abstimeout);
} else {
pthread_cond_wait(&queue->cond, &queue->mutex);
}
}
if (ret == ETIMEDOUT) {
pthread_mutex_unlock(&queue->mutex);
return ret;
}
firstrec = queue->first;
queue->first = queue->first->next;
queue->length--;
if (queue->first == NULL) {
queue->last = NULL; // we know this since we hold the lock
queue->length = 0;
}
msg->data = firstrec->msg.data;
msg->msgtype = firstrec->msg.msgtype;
msg->qlength = queue->length;
release_msglist(queue,firstrec);
pthread_mutex_unlock(&queue->mutex);
return 0;
}
//maybe caller should supply a callback for cleaning the elements ?
int thread_queue_cleanup(struct threadqueue *queue, int freedata)
{
struct msglist *rec;
struct msglist *next;
struct msglist *recs[2];
int ret,i;
if (queue == NULL) {
return EINVAL;
}
pthread_mutex_lock(&queue->mutex);
recs[0] = queue->first;
recs[1] = queue->msgpool;
for(i = 0; i < 2 ; i++) {
rec = recs[i];
while (rec) {
next = rec->next;
if (freedata) {
free(rec->msg.data);
}
free(rec);
rec = next;
}
}
pthread_mutex_unlock(&queue->mutex);
ret = pthread_mutex_destroy(&queue->mutex);
pthread_cond_destroy(&queue->cond);
return ret;
}
long thread_queue_length(struct threadqueue *queue)
{
long counter;
// get the length properly
pthread_mutex_lock(&queue->mutex);
counter = queue->length;
pthread_mutex_unlock(&queue->mutex);
return counter;
}

185
util/thread_queue.h Normal file
View File

@ -0,0 +1,185 @@
#ifndef _OAUTH2LIB_UTIL_THREAD_QUEUE_
#define _OAUTH2LIB_UTIL_THREAD_QUEUE_
// https://stackoverflow.com/questions/4577961/pthread-synchronized-blocking-queue
#include <pthread.h>
/**
* @defgroup ThreadQueue ThreadQueue
*
* Little API for waitable queues, typically used for passing messages
* between threads.
*
*/
/**
* @mainpage
*/
/**
* A thread message.
*
* @ingroup ThreadQueue
*
* This is used for passing to #thread_queue_get for retreive messages.
* the date is stored in the data member, the message type in the #msgtype.
*
* Typical:
* @code
* struct threadmsg;
* struct myfoo *foo;
* while(1)
* ret = thread_queue_get(&queue,NULL,&message);
* ..
* foo = msg.data;
* switch(msg.msgtype){
* ...
* }
* }
* @endcode
*
*/
struct threadmsg{
/**
* Holds the data.
*/
void *data;
/**
* Holds the messagetype
*/
long msgtype;
/**
* Holds the current queue lenght. Might not be meaningful if there's several readers
*/
long qlength;
};
/**
* A TthreadQueue
*
* @ingroup ThreadQueue
*
* You should threat this struct as opaque, never ever set/get any
* of the variables. You have been warned.
*/
struct threadqueue {
/**
* Length of the queue, never set this, never read this.
* Use #threadqueue_length to read it.
*/
long length;
/**
* Mutex for the queue, never touch.
*/
pthread_mutex_t mutex;
/**
* Condition variable for the queue, never touch.
*/
pthread_cond_t cond;
/**
* Internal pointers for the queue, never touch.
*/
struct msglist *first,*last;
/**
* Internal cache of msglists
*/
struct msglist *msgpool;
/**
* No. of elements in the msgpool
*/
long msgpool_length;
};
/**
* Initializes a queue.
*
* @ingroup ThreadQueue
*
* thread_queue_init initializes a new threadqueue. A new queue must always
* be initialized before it is used.
*
* @param queue Pointer to the queue that should be initialized
* @return 0 on success see pthread_mutex_init
*/
int thread_queue_init(struct threadqueue *queue);
/**
* Adds a message to a queue
*
* @ingroup ThreadQueue
*
* thread_queue_add adds a "message" to the specified queue, a message
* is just a pointer to a anything of the users choice. Nothing is copied
* so the user must keep track on (de)allocation of the data.
* A message type is also specified, it is not used for anything else than
* given back when a message is retreived from the queue.
*
* @param queue Pointer to the queue on where the message should be added.
* @param data the "message".
* @param msgtype a long specifying the message type, choice of the user.
* @return 0 on succes ENOMEM if out of memory EINVAL if queue is NULL
*/
int thread_queue_add(struct threadqueue *queue, void *data, long msgtype);
/**
* Gets a message from a queue
*
* @ingroup ThreadQueue
*
* thread_queue_get gets a message from the specified queue, it will block
* the caling thread untill a message arrives, or the (optional) timeout occurs.
* If timeout is NULL, there will be no timeout, and thread_queue_get will wait
* untill a message arrives.
*
* struct timespec is defined as:
* @code
* struct timespec {
* long tv_sec; // seconds
* long tv_nsec; // nanoseconds
* };
* @endcode
*
* @param queue Pointer to the queue to wait on for a message.
* @param timeout timeout on how long to wait on a message
* @param msg pointer that is filled in with mesagetype and data
*
* @return 0 on success EINVAL if queue is NULL ETIMEDOUT if timeout occurs
*/
int thread_queue_get(struct threadqueue *queue, const struct timespec *timeout, struct threadmsg *msg);
/**
* Gets the length of a queue
*
* @ingroup ThreadQueue
*
* threadqueue_length returns the number of messages waiting in the queue
*
* @param queue Pointer to the queue for which to get the length
* @return the length(number of pending messages) in the queue
*/
long thread_queue_length( struct threadqueue *queue );
/**
* @ingroup ThreadQueue
* Cleans up the queue.
*
* threadqueue_cleanup cleans up and destroys the queue.
* This will remove all messages from a queue, and reset it. If
* freedata is != 0 free(3) will be called on all pending messages in the queue
* You cannot call this if there are someone currently adding or getting messages
* from the queue.
* After a queue have been cleaned, it cannot be used again untill #thread_queue_init
* has been called on the queue.
*
* @param queue Pointer to the queue that should be cleaned
* @param freedata set to nonzero if free(3) should be called on remaining
* messages
* @return 0 on success EINVAL if queue is NULL EBUSY if someone is holding any locks on the queue
*/
int thread_queue_cleanup(struct threadqueue *queue, int freedata);
#endif //_OAUTH2LIB_UTIL_THREAD_QUEUE_