From cba77c00ba6741e1b406ec4baba76c33086122b8 Mon Sep 17 00:00:00 2001 From: Patrick Date: Wed, 6 Nov 2024 22:55:07 +0100 Subject: [PATCH] first --- CMakeLists.txt | 24 ++++ base64.c | 286 +++++++++++++++++++++++++++++++++++++ base64.h | 37 +++++ log.c | 176 +++++++++++++++++++++++ log.c.old | 97 +++++++++++++ log.h | 54 +++++++ log.h.old | 38 +++++ main.c | 23 +++ main.c.old | 131 +++++++++++++++++ networking.c | 221 ++++++++++++++++++++++++++++ networking.h | 11 ++ pkce.c | 145 +++++++++++++++++++ pkce.h | 6 + server.c | 141 ++++++++++++++++++ server.h | 20 +++ ssl.c | 183 ++++++++++++++++++++++++ ssl.h | 17 +++ util/a.out | Bin 0 -> 20968 bytes util/sorted_str_set.c | 174 ++++++++++++++++++++++ util/sorted_str_set.h | 69 +++++++++ util/test_sorted_str_set.c | 67 +++++++++ util/thread_queue.c | 204 ++++++++++++++++++++++++++ util/thread_queue.h | 185 ++++++++++++++++++++++++ 23 files changed, 2309 insertions(+) create mode 100644 CMakeLists.txt create mode 100644 base64.c create mode 100644 base64.h create mode 100644 log.c create mode 100644 log.c.old create mode 100644 log.h create mode 100644 log.h.old create mode 100644 main.c create mode 100644 main.c.old create mode 100644 networking.c create mode 100644 networking.h create mode 100644 pkce.c create mode 100644 pkce.h create mode 100644 server.c create mode 100644 server.h create mode 100644 ssl.c create mode 100644 ssl.h create mode 100755 util/a.out create mode 100644 util/sorted_str_set.c create mode 100644 util/sorted_str_set.h create mode 100644 util/test_sorted_str_set.c create mode 100644 util/thread_queue.c create mode 100644 util/thread_queue.h diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..94a88aa --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 3.20) + +project(OAuth2Client C) +set(CMAKE_C_STANDARD 23) + +set(CMAKE_BUILD_TYPE "DEBUG") +set(CMAKE_EXPORT_COMPILE_COMMANDS on) + +find_package(OpenSSL REQUIRED) + +add_library(OAuth2Lib + util/thread_queue.c + util/sorted_str_set.c + networking.c + pkce.c + base64.c + ssl.c + log.c + server.c) +target_include_directories(OAuth2Lib PRIVATE ${OpenSSL_INCLUDE_DIRS}) +target_link_libraries(OAuth2Lib PRIVATE OpenSSL::SSL OpenSSL::Crypto) + +add_executable(main main.c) +target_link_libraries(main OAuth2Lib) diff --git a/base64.c b/base64.c new file mode 100644 index 0000000..ac8a99d --- /dev/null +++ b/base64.c @@ -0,0 +1,286 @@ +#include "base64.h" + +#include +#include + +static const char base64url_table[65] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + +// http://web.mit.edu/freebsd/head/contrib/wpa/src/utils/base64.c +/* + * Base64 encoding/decoding (RFC1341) + * Copyright (c) 2005-2011, Jouni Malinen + * + * This software may be distributed under the terms of the BSD license. + * See README for more details. + */ +static const unsigned char base64_table[65] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +/** +* base64_encode - Base64 encode +* @src: Data to be encoded +* @len: Length of the data to be encoded +* @out_len: Pointer to output length variable, or %NULL if not used +* Returns: Allocated buffer of out_len bytes of encoded data, +* or %NULL on failure +* +* Caller is responsible for freeing the returned buffer. Returned buffer is +* nul terminated to make it easier to use as a C string. The nul terminator is +* not included in out_len. +*/ +char * base64_encode(const uint8_t *src, size_t len, + size_t *out_len) +{ + unsigned char *out, *pos; + const unsigned char *end, *in; + size_t olen; + int line_len; + + olen = len * 4 / 3 + 4; /* 3-byte blocks to 4-byte */ + olen += olen / 72; /* line feeds */ + olen++; /* nul termination */ + if (olen < len) + return NULL; /* integer overflow */ + out = malloc(olen); + if (out == NULL) + return NULL; + + end = src + len; + in = src; + pos = out; + line_len = 0; + while (end - in >= 3) { + *pos++ = base64_table[in[0] >> 2]; + *pos++ = base64_table[((in[0] & 0x03) << 4) | (in[1] >> 4)]; + *pos++ = base64_table[((in[1] & 0x0f) << 2) | (in[2] >> 6)]; + *pos++ = base64_table[in[2] & 0x3f]; + in += 3; + line_len += 4; + if (line_len >= 72) { + *pos++ = '\n'; + line_len = 0; + } + } + + if (end - in) { + *pos++ = base64_table[in[0] >> 2]; + if (end - in == 1) { + *pos++ = base64_table[(in[0] & 0x03) << 4]; + *pos++ = '='; + } else { + *pos++ = base64_table[((in[0] & 0x03) << 4) | + (in[1] >> 4)]; + *pos++ = base64_table[(in[1] & 0x0f) << 2]; + } + *pos++ = '='; + line_len += 4; + } + + if (line_len) + *pos++ = '\n'; + + *pos = '\0'; + if (out_len) + *out_len = pos - out; + return out; +} + + +/** + * base64_decode - Base64 decode + * @src: Data to be decoded + * @len: Length of the data to be decoded + * @out_len: Pointer to output length variable + * Returns: Allocated buffer of out_len bytes of decoded data, + * or %NULL on failure + * + * Caller is responsible for freeing the returned buffer. + */ +uint8_t* base64_decode(const char *src, size_t len, + size_t *out_len) +{ + unsigned char dtable[256], *out, *pos, block[4], tmp; + size_t i, count, olen; + int pad = 0; + + memset(dtable, 0x80, 256); + for (i = 0; i < sizeof(base64_table) - 1; i++) + dtable[base64_table[i]] = (unsigned char) i; + dtable['='] = 0; + + count = 0; + for (i = 0; i < len; i++) { + if (dtable[src[i]] != 0x80) + count++; + } + + if (count == 0 || count % 4) + return NULL; + + olen = count / 4 * 3; + pos = out = malloc(olen); + if (out == NULL) + return NULL; + + count = 0; + for (i = 0; i < len; i++) { + tmp = dtable[src[i]]; + if (tmp == 0x80) + continue; + + if (src[i] == '=') + pad++; + block[count] = tmp; + count++; + if (count == 4) { + *pos++ = (block[0] << 2) | (block[1] >> 4); + *pos++ = (block[1] << 4) | (block[2] >> 2); + *pos++ = (block[2] << 6) | block[3]; + count = 0; + if (pad) { + if (pad == 1) + pos--; + else if (pad == 2) + pos -= 2; + else { + /* Invalid padding */ + free(out); + return NULL; + } + break; + } + } + } + + *out_len = pos - out; + return out; +} + +char * base64url_encode(const uint8_t *src, size_t len, + size_t *out_len) +{ + unsigned char *out, *pos; + const unsigned char *end, *in; + size_t olen; + int line_len; + + olen = len * 4 / 3 + 4; /* 3-byte blocks to 4-byte */ + olen += olen / 72; /* line feeds */ + olen++; /* nul termination */ + if (olen < len) + return NULL; /* integer overflow */ + out = malloc(olen); + if (out == NULL) + return NULL; + + end = src + len; + in = src; + pos = out; + line_len = 0; + while (end - in >= 3) { + *pos++ = base64url_table[in[0] >> 2]; + *pos++ = base64url_table[((in[0] & 0x03) << 4) | (in[1] >> 4)]; + *pos++ = base64url_table[((in[1] & 0x0f) << 2) | (in[2] >> 6)]; + *pos++ = base64url_table[in[2] & 0x3f]; + in += 3; + line_len += 4; + if (line_len >= 72) { + *pos++ = '\n'; + line_len = 0; + } + } + + if (end - in) { + *pos++ = base64url_table[in[0] >> 2]; + if (end - in == 1) { + *pos++ = base64url_table[(in[0] & 0x03) << 4]; + *pos++ = '='; + } else { + *pos++ = base64url_table[((in[0] & 0x03) << 4) | + (in[1] >> 4)]; + *pos++ = base64url_table[(in[1] & 0x0f) << 2]; + } + *pos++ = '='; + line_len += 4; + } + + if (line_len) + *pos++ = '\n'; + + *pos = '\0'; + if (out_len) + *out_len = pos - out; + return out; +} + + +/** + * base64_decode - Base64 decode + * @src: Data to be decoded + * @len: Length of the data to be decoded + * @out_len: Pointer to output length variable + * Returns: Allocated buffer of out_len bytes of decoded data, + * or %NULL on failure + * + * Caller is responsible for freeing the returned buffer. + */ +uint8_t* base64url_decode(const char *src, size_t len, + size_t *out_len) +{ + unsigned char dtable[256], *out, *pos, block[4], tmp; + size_t i, count, olen; + int pad = 0; + + memset(dtable, 0x80, 256); + for (i = 0; i < sizeof(base64url_table) - 1; i++) + dtable[base64url_table[i]] = (unsigned char) i; + dtable['='] = 0; + + count = 0; + for (i = 0; i < len; i++) { + if (dtable[src[i]] != 0x80) + count++; + } + + if (count == 0 || count % 4) + return NULL; + + olen = count / 4 * 3; + pos = out = malloc(olen); + if (out == NULL) + return NULL; + + count = 0; + for (i = 0; i < len; i++) { + tmp = dtable[src[i]]; + if (tmp == 0x80) + continue; + + if (src[i] == '=') + pad++; + block[count] = tmp; + count++; + if (count == 4) { + *pos++ = (block[0] << 2) | (block[1] >> 4); + *pos++ = (block[1] << 4) | (block[2] >> 2); + *pos++ = (block[2] << 6) | block[3]; + count = 0; + if (pad) { + if (pad == 1) + pos--; + else if (pad == 2) + pos -= 2; + else { + /* Invalid padding */ + free(out); + return NULL; + } + break; + } + } + } + + *out_len = pos - out; + return out; +} diff --git a/base64.h b/base64.h new file mode 100644 index 0000000..82d2b41 --- /dev/null +++ b/base64.h @@ -0,0 +1,37 @@ +#ifndef _OSAUTH2LIB_BASE64_ +#define _OSAUTH2LIB_BASE64_ + +#include +#include + +/** +* base64_encode - Base64 encode +* @src: Data to be encoded +* @len: Length of the data to be encoded +* @out_len: Pointer to output length variable, or %NULL if not used +* Returns: Allocated buffer of out_len bytes of encoded data, +* or %NULL on failure +* +* Caller is responsible for freeing the returned buffer. Returned buffer is +* nul terminated to make it easier to use as a C string. The nul terminator is +* not included in out_len. +*/ +char *base64_encode(const uint8_t* src, size_t len, size_t* out_len); + +/** + * base64_decode - Base64 decode + * @src: Data to be decoded + * @len: Length of the data to be decoded + * @out_len: Pointer to output length variable + * Returns: Allocated buffer of out_len bytes of decoded data, + * or %NULL on failure + * + * Caller is responsible for freeing the returned buffer. + */ +uint8_t *base64_decode(const char* src, size_t len, size_t* out_len); + + +char *base64url_encode(const uint8_t* src, size_t len, size_t* out_len); +uint8_t *base64url_decode(const char* src, size_t len, size_t* out_len); + +#endif //_OSAUTH2LIB_BASE64_ diff --git a/log.c b/log.c new file mode 100644 index 0000000..156bf3e --- /dev/null +++ b/log.c @@ -0,0 +1,176 @@ +#include "log.h" + +#include +#include +#include + +#include + +#include "util/sorted_str_set.h" + +static const char* LOG_LEVEL_STRING_TABLE[] = { + "ERROR", + "WARN", + "INFO", + "DEBUG", + "TRACE" +}; + +// Maybe add __attribute__((constructor)) in the future if available to skip _config_initialized +static bool _user_default_config_initialized = false; +static struct log_config_s _user_default_config; + +static sorted_str_set* _module_configs = NULL; + +static inline struct log_config_s _default_config() { + + if (_user_default_config_initialized) { + return _user_default_config; + } + + return (struct log_config_s) { + .loc_error = LOG_LOC_STDS(stderr), + .loc_warn = LOG_LOC_STDS(stderr), + .loc_info = LOG_LOC_STDS(stdout), + .loc_debug = LOG_LOC_NONE(), + .loc_trace = LOG_LOC_NONE(), + + .timestamp = true, + }; +} + +void log_set_default_config(struct log_config_s config) { + _user_default_config = config; + _user_default_config_initialized = true; +} + +struct log_config_s* log_get_config(const char* module_name) { + + assert(module_name != NULL); + + if (!_module_configs) { + if (!(_module_configs = sorted_str_set_new(sizeof(struct log_config_s)))) { + return NULL; + } + } + + struct log_config_s default_conf = _default_config(); + return sorted_str_set_insert(_module_configs, module_name, &default_conf); +} + +void log_log(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...) { + + assert(module != NULL); + + struct log_config_s* config = log_get_config(module); + assert(config); + + struct log_loc* dest_loc; + switch (log_level) { + case LOG_ERR: dest_loc = &config->loc_error; break; + case LOG_WARN: dest_loc = &config->loc_warn; break; + case LOG_INFO: dest_loc = &config->loc_info; break; + case LOG_DEBUG: dest_loc = &config->loc_debug; break; + case LOG_TRACE: dest_loc = &config->loc_trace; break; + default: return; + } + + if (dest_loc->type == LOG_LOC_TYPE_DISABLED) { + return; + } else if (dest_loc->type == LOG_LOC_TYPE_FS) { + + if (!dest_loc->location) { + if (!(dest_loc->location = fopen(dest_loc->file_name, "a"))) { + return; + } + } + + } else if (dest_loc->type != LOG_LOC_TYPE_STDS) { + return; + } + + FILE* dest = dest_loc->location; + + va_list varargs; + va_start(varargs, fmt); + + if (config->timestamp) { + struct timespec utc_time; + struct tm local_time; + char time_str_buf[10+1+8+1]; + timespec_get(&utc_time, TIME_UTC); + localtime_r(&(utc_time.tv_sec), &local_time); + + if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) { + fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec); + } + } + + fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], module); + vfprintf(dest, fmt, varargs); + fprintf(dest, "\n"); + + if (dest_loc->type == LOG_LOC_TYPE_FS && !dest_loc->keep_file_open) { + fclose(dest_loc->location); + dest_loc->location = NULL; + } +} + +void log_log_ssl_err(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...) { + assert(module != NULL); + + struct log_config_s* config = log_get_config(module); + assert(config); + + struct log_loc* dest_loc; + switch (log_level) { + case LOG_ERR: dest_loc = &config->loc_error; break; + case LOG_WARN: dest_loc = &config->loc_warn; break; + case LOG_INFO: dest_loc = &config->loc_info; break; + case LOG_DEBUG: dest_loc = &config->loc_debug; break; + case LOG_TRACE: dest_loc = &config->loc_trace; break; + default: return; + } + + if (dest_loc->type == LOG_LOC_TYPE_DISABLED) { + return; + } else if (dest_loc->type == LOG_LOC_TYPE_FS) { + + if (!dest_loc->location) { + if (!(dest_loc->location = fopen(dest_loc->file_name, "a"))) { + return; + } + } + + } else if (dest_loc->type != LOG_LOC_TYPE_STDS) { + return; + } + + FILE* dest = dest_loc->location; + + va_list varargs; + va_start(varargs, fmt); + + if (config->timestamp) { + struct timespec utc_time; + struct tm local_time; + char time_str_buf[10+1+8+1]; + timespec_get(&utc_time, TIME_UTC); + localtime_r(&(utc_time.tv_sec), &local_time); + + if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) { + fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec); + } + } + + fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], module); + vfprintf(dest, fmt, varargs); + fprintf(dest, "\n"); + ERR_print_errors_fp(dest); + + if (dest_loc->type == LOG_LOC_TYPE_FS && !dest_loc->keep_file_open) { + fclose(dest_loc->location); + dest_loc->location = NULL; + } + +} diff --git a/log.c.old b/log.c.old new file mode 100644 index 0000000..bb69cc9 --- /dev/null +++ b/log.c.old @@ -0,0 +1,97 @@ +#include "log.h" + +#include +#include +#include + +#include + +static const char* LOG_LEVEL_STRING_TABLE[] = { + "ERROR", + "WARN", + "INFO", + "DEBUG", + "TRACE" +}; + +// Maybe add __attribute__((constructor)) in the future if available to skip _config_initialized +static bool _user_default_config_initialized = false; +static struct log_config_s _user_default_config; + +static inline struct log_config_s _default_config() { + + if (_user_default_config_initialized) { + return _user_default_config; + } + + return (struct log_config_s) { + .loc_error = stderr, + .loc_warn = stderr, + .loc_info = stdout, + .loc_debug = NULL, + .loc_trace = NULL, + + .timestamp = true, + }; +} + +void log_set_default_config(struct log_config_s config) { + _user_default_config = config; + _user_default_config_initialized = true; +} + +struct logger_s log_new(const char* module_name) { + return (struct logger_s) { + .module_name = module_name, + .config = _default_config(), + }; +} + +void log_log(const struct logger_s* logger, enum LOG_LEVEL log_level, const char* fmt, ...) { + + assert(logger != NULL); + + FILE* dest = NULL; + switch (log_level) { + case LOG_ERR: dest = logger->config.loc_error; break; + case LOG_WARN: dest = logger->config.loc_warn; break; + case LOG_INFO: dest = logger->config.loc_info; break; + case LOG_DEBUG: dest = logger->config.loc_debug; break; + case LOG_TRACE: dest = logger->config.loc_trace; break; + default: break; + } + + if (dest == NULL) { + return; + } + + va_list varargs; + va_start(varargs, fmt); + + if (logger->config.timestamp) { + struct timespec utc_time; + struct tm local_time; + char time_str_buf[10+1+8+1]; + timespec_get(&utc_time, TIME_UTC); + localtime_r(&(utc_time.tv_sec), &local_time); + + if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) { + fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec); + } + } + + fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], logger->module_name); + vfprintf(dest, fmt, varargs); + fprintf(dest, "\n"); +} + +void log_log_ssl_err(const struct logger_s* logger, enum LOG_LEVEL log_level, unsigned long error, const char* message_fmt, ...) { + char error_buf[256]; + ERR_error_string_n(ERR_get_error(), error_buf, sizeof(error_buf)); + + va_list varargs; + va_start(varargs, message_fmt); + + log_log(logger, log_level, message_fmt, varargs); + log_log(logger, log_level, "%s", error_buf); +} diff --git a/log.h b/log.h new file mode 100644 index 0000000..f986089 --- /dev/null +++ b/log.h @@ -0,0 +1,54 @@ +#ifndef _OAUTH2LIB_LOG_ +#define _OAUTH2LIB_LOG_ + +#include +#include + +enum LOG_LEVEL { + LOG_ERR, + LOG_WARN, + LOG_INFO, + LOG_DEBUG, + LOG_TRACE, +}; + +enum LOG_LOC_TYPE { + LOG_LOC_TYPE_DISABLED, + LOG_LOC_TYPE_FS, + LOG_LOC_TYPE_STDS +}; + +struct log_loc { + enum LOG_LOC_TYPE type; + FILE* location; + const char* file_name; + bool keep_file_open; +}; + +#define LOG_LOC_FILE(file_location, keep_open) ((struct log_loc) { .type=LOG_LOC_TYPE_FS, .location=NULL, .file_name=file_location, .keep_file_open=keep_open }) +#define LOG_LOC_STDS(stream) ((struct log_loc) { .type=LOG_LOC_TYPE_STDS, .location=stream, .file_name=NULL }) +#define LOG_LOC_NONE() ((struct log_loc) { .type=LOG_LOC_TYPE_DISABLED }) + +struct log_config_s { + struct log_loc loc_error; + struct log_loc loc_warn; + struct log_loc loc_info; + struct log_loc loc_debug; + struct log_loc loc_trace; + + bool timestamp; +}; + +struct logger_s { + const char* module_name; + struct log_config_s config; +}; + +void log_set_default_config(struct log_config_s config); +struct log_config_s* log_get_config(const char* module_name); + +void log_log(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...); + +void log_log_ssl_err(const char* module, enum LOG_LEVEL log_level, const char* message_fmt, ...); + +#endif //_OAUTH2LIB_LOG_ diff --git a/log.h.old b/log.h.old new file mode 100644 index 0000000..f34cba2 --- /dev/null +++ b/log.h.old @@ -0,0 +1,38 @@ +#ifndef _OAUTH2LIB_LOG_ +#define _OAUTH2LIB_LOG_ + +#include +#include +#include + +enum LOG_LEVEL { + LOG_ERR, + LOG_WARN, + LOG_INFO, + LOG_DEBUG, + LOG_TRACE, +}; + +struct log_config_s { + FILE* loc_error; + FILE* loc_warn; + FILE* loc_info; + FILE* loc_debug; + FILE* loc_trace; + + bool timestamp; +}; + +struct logger_s { + const char* module_name; + struct log_config_s config; +}; + +void log_set_default_config(struct log_config_s config); +struct logger_s log_new(const char* module_name); + +void log_log(const struct logger_s* logger, enum LOG_LEVEL log_level, const char* fmt, ...); + +void log_log_ssl_err(const struct logger_s* logger, enum LOG_LEVEL log_level, unsigned long error, const char* message_fmt, ...); + +#endif //_OAUTH2LIB_LOG_ diff --git a/main.c b/main.c new file mode 100644 index 0000000..004aedc --- /dev/null +++ b/main.c @@ -0,0 +1,23 @@ +#include + +#include "networking.h" +#include "pkce.h" +#include "log.h" +#include "server.h" +#include "ssl.h" + +#include + +int main(int argc, const char* argv[]) { + + /*SSL_CTX* ctx = SSL_CTX_new(TLS_server_method()); + ssl_use_ppcerts(ctx);*/ + + // tcp_get(argv[1]); + + //do_pkce(); + + server_start(); + + return 0; +} diff --git a/main.c.old b/main.c.old new file mode 100644 index 0000000..d402245 --- /dev/null +++ b/main.c.old @@ -0,0 +1,131 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//#include "dns.h" + +#define WEB_ADDR "www.onedrive.com" + +struct socket_data_chunk { + char data[4096]; + struct socket_data_chunk *next; +}; + +void free_socketdata(struct socket_data_chunk* data_chunk) { + for (struct socket_data_chunk *iter = data_chunk; iter != NULL; ) { + struct socket_data_chunk *data = iter; + iter = iter->next; + + free(data); + } +} + +int strncmp_end(const char *str, size_t max_strlen, const char *suffix, size_t max_suffix_len) +{ + if (!str || !suffix) + return 0; + size_t lenstr = strnlen(str, max_strlen); + size_t lensuffix = strnlen(suffix, max_suffix_len); + if (lensuffix > lenstr) + return 0; + return strncmp(str + lenstr - lensuffix, suffix, lensuffix); +} + + +int main() { + + //oauth2_dns_resolve("onedrive.com"); + + struct addrinfo resolve_hints = { + .ai_family = AF_INET, //AF_UNSPEC, // IPv4 or IPv6 + .ai_socktype = SOCK_STREAM, + .ai_flags = 0, + .ai_protocol = 0 + }; + + struct addrinfo *resolve_result; + + int get_addr_res = getaddrinfo(WEB_ADDR, NULL, &resolve_hints, &resolve_result); + if (get_addr_res != 0) { + printf("getaddrinfo failed: %d\n", get_addr_res); + exit(-1); + } + char ip_addr[50]; + printf("IP address: %s\n", inet_ntop(AF_INET, &((struct sockaddr_in*) resolve_result->ai_addr)->sin_addr, ip_addr, sizeof(ip_addr))); + printf("Port: %d\n", ntohs(((struct sockaddr_in*)resolve_result->ai_addr)->sin_port)); + + ((struct sockaddr_in*)resolve_result->ai_addr)->sin_port = htons(80); + + int sock_desc = socket(AF_INET, SOCK_STREAM, 0); + if (sock_desc == -1) { + printf("Socket creation failed. Errno: %d\n", errno); + exit(-1); + } + + printf("created socket\n"); + + struct sockaddr_in server_addr = *(struct sockaddr_in *)resolve_result->ai_addr; + +// struct sockaddr_in server_addr = { +// .sin_family = AF_INET, +// .sin_addr = inet_addr("127.0.0.1"), +// .sin_port = htons(8080) +// }; + + if (connect(sock_desc, (struct sockaddr*)&server_addr, sizeof(server_addr)) != 0) { + printf("Socket connection failed. Errno: %d\n", errno); + exit(-1); + close(sock_desc); + } + + printf("Connected to address\n"); + + const char *msg = + "GET / HTTP/1.1\r\n" + "Host: " WEB_ADDR "\r\n" + "\r\n"; + write(sock_desc, msg, strlen(msg)); + + printf("sent: \n%s\n", msg); + + struct socket_data_chunk *data = NULL; + struct socket_data_chunk **it = &data; + do { + struct socket_data_chunk *nit = calloc(1, sizeof(struct socket_data_chunk)); + if (!nit) { + printf("failed to allocate data for receive\n"); + exit(-1); + } + + if (*it) { + (*it)->next = nit; + } else { + *it = nit; + } + it = &nit; + + ssize_t bytes_read = recv(sock_desc, &nit->data, sizeof(nit->data), 0); + if (bytes_read == -1) { + printf("Failed to recieve data. Errno: %d", errno); + exit(-1); + } + + printf("received (%ld) bytes:\n%s\n", bytes_read, nit->data); + + //} while (!strstr((*it)->data, "\r\n\r\n")); + } while(strncmp_end(data->data, sizeof((*it)->data), "\r\n\r\n", 5)); + + printf("Finishing\n"); + + close(sock_desc); + freeaddrinfo(resolve_result); + free_socketdata(data); + + return 0; +} diff --git a/networking.c b/networking.c new file mode 100644 index 0000000..38c9fcd --- /dev/null +++ b/networking.c @@ -0,0 +1,221 @@ +#include "networking.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include "ssl.h" + +typedef int socket_fd; + +static SSL_CTX *ssl_ctx = NULL; + +char *strstr_fixn(const char *haystack, const char *needle, size_t len) { + size_t needle_len; + + if (0 == (needle_len = strnlen(needle, len))) + return (char *)haystack; + + for (size_t i = 0; i <= len - needle_len; i++) + { + if ((haystack[0] == needle[0]) && + (0 == strncmp(haystack, needle, needle_len))) + return (char *)haystack; + + haystack++; + } + return NULL; +} + + + +static socket_fd connect_to_http(const char* hostname) { + + struct addrinfo resolve_hints = { + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + }; + + struct addrinfo* resolve_result; + int result = getaddrinfo(hostname, "https", &resolve_hints, &resolve_result); + + if (result != 0) { + fprintf(stderr, "Failed to get IP address of \"%s\" for an http connection. Error: %s\n", hostname, gai_strerror(result)); + return -1; + } + + socket_fd sock_desc = -1; + for (struct addrinfo *it = resolve_result; it != NULL; it = it->ai_next) { + + sock_desc = socket(it->ai_family, it->ai_socktype, it->ai_protocol); + if (sock_desc == -1) { + continue; + } + + if (connect(sock_desc, it->ai_addr, it->ai_addrlen) != -1) { + uint16_t port = ntohs( + (it->ai_family == AF_INET) + ? ((struct sockaddr_in*) it->ai_addr)->sin_port + : ((struct sockaddr_in6*) it->ai_addr)->sin6_port); + printf("Connected to \"%s\" on port %d\n", hostname, port); + break; + } + close(sock_desc); + } + + freeaddrinfo(resolve_result); + + return sock_desc; +} +void tcp_get(const char* addr) { + + initialize_ssl(); + printf("initialized ssl\n"); + + + socket_fd sock_fd = connect_to_http(addr); + + printf("Got socket: %d\n", sock_fd); + + SSL* ssl_conn = SSL_new(ssl_ctx); + if (!ssl_conn) { + fprintf(stderr, "Failed to create an SSL connection provider.\n"); + ERR_print_errors_fp(stderr); + + close(sock_fd); + return; + } + + SSL_set_fd(ssl_conn, sock_fd); + + if (!SSL_set_tlsext_host_name(ssl_conn, addr)) { + fprintf(stderr, "Failed to set SNI hostname.\n"); + ERR_print_errors_fp(stderr); + + SSL_free(ssl_conn); + return; + } + + if (!SSL_set1_host(ssl_conn, addr)) { + fprintf(stderr, "Failed to set certificate hostname.\n"); + ERR_print_errors_fp(stderr); + + SSL_free(ssl_conn); + return; + } + + if (SSL_connect(ssl_conn) < 1) { + fprintf(stderr, "Failed to connect to server.\n"); + + if (SSL_get_verify_result(ssl_conn) != X509_V_OK) { + fprintf(stderr, "Verification error: %s\n", X509_verify_cert_error_string(SSL_get_verify_result(ssl_conn))); + } else { + ERR_print_errors_fp(stderr); + } + + SSL_free(ssl_conn); + return; + } + + printf("Connected to peer.\n"); + + const char* HTTP_HEAD = "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "Host: "; + const char* HTTP_HEAD_END = "\r\n\r\n"; + + const char* msg[] = { HTTP_HEAD, addr, HTTP_HEAD_END }; + + for (size_t i = 0; i < sizeof(msg)/sizeof(msg[0]); ++i) { + + if (!SSL_write(ssl_conn, msg[i], strlen(msg[i]))) { + fprintf(stderr, "Failed to send message: \"%s\"\n", msg[i]); + ERR_print_errors_fp(stderr); + + SSL_free(ssl_conn); + return; + } + } + + printf("Sent HTTP request to peer.\n"); + + size_t read_bytes; + char read_buf[160]; + + enum { HTTP_CONTENT_LENGTH, HTTP_CONTENT_CHUNKED } http_content_transmit_type; + size_t total_read_bytes; + size_t expected_bytes = -1; + bool body = false; + + char last_3_of_prev[3] = {}; + + printf("\n"); + while (SSL_read_ex(ssl_conn, read_buf, sizeof(read_buf), &read_bytes)) { + total_read_bytes += read_bytes; + + if ((http_content_transmit_type == HTTP_CONTENT_CHUNKED && read_buf[0] == '0') + || http_content_transmit_type == HTTP_CONTENT_LENGTH && total_read_bytes == expected_bytes) { + SSL_set_options(ssl_conn, SSL_OP_IGNORE_UNEXPECTED_EOF); + } + + if (!body) { + const char* content_length_str = "Content-Length: "; + const char* transfer_encoding_str = "Transfer-Encoding: chunked"; + + char* content_length_start = strstr_fixn(read_buf, content_length_str, read_bytes); + if (content_length_start != NULL) { + http_content_transmit_type = HTTP_CONTENT_LENGTH; + + content_length_start += strlen(content_length_str); + expected_bytes = strtol(content_length_start, NULL, 10); + } + + if (strstr_fixn(read_buf, transfer_encoding_str, read_bytes)) { + http_content_transmit_type = HTTP_CONTENT_CHUNKED; + } + + char concat[] = { last_3_of_prev[0], last_3_of_prev[1], last_3_of_prev[2], read_buf[0], read_buf[1], read_buf[2], 0 }; + char *concat_start; + if ((concat_start = strstr(concat, HTTP_HEAD_END))) { + body = true; + total_read_bytes = read_bytes - (concat_start - concat + 1); + } else { + char* body_start = strstr_fixn(read_buf, HTTP_HEAD_END, read_bytes); + if (body_start) { + body = true; + total_read_bytes = read_bytes - (body_start + strlen(HTTP_HEAD_END) - read_buf); + } + } + } + + + fwrite(read_buf, 1, read_bytes, stdout); + } + printf("\n"); + + if (SSL_get_error(ssl_conn, 0) != SSL_ERROR_ZERO_RETURN) { + fprintf(stderr, "Failed to read response.\n"); + ERR_print_errors_fp(stderr); + } + + printf("Recieved %ld bytes", total_read_bytes); + printf(expected_bytes != -1 ? ", expected: %ld bytes\n" : "\n", expected_bytes); + printf("Ignore unexpected eof? %s\n", SSL_get_options(ssl_conn) & SSL_OP_IGNORE_UNEXPECTED_EOF ? "true" : "false"); + + if (SSL_shutdown(ssl_conn) < 1) { + fprintf(stderr, "Failed to shut down connection gracefully.\n"); + ERR_print_errors_fp(stderr); + } + + SSL_free(ssl_conn); + + printf("end\n"); +} diff --git a/networking.h b/networking.h new file mode 100644 index 0000000..caf3254 --- /dev/null +++ b/networking.h @@ -0,0 +1,11 @@ +#ifndef _OAUTH2LIB_NETWORKING_ +#define _OAUTH2LIB_NETWORKING_ + +struct content { + +}; + +void tcp_get(const char* addr); + + +#endif //_OAUTH2LIB_NETWORKING_ diff --git a/pkce.c b/pkce.c new file mode 100644 index 0000000..4f44e45 --- /dev/null +++ b/pkce.c @@ -0,0 +1,145 @@ +#include "pkce.h" + +#include +#include + +#include +#include +#include + +#include "base64.h" +#include "log.h" +#include "ssl.h" + +#define MODULE_NAME "PKCE" + +#define CODE_VERIFIER_LENGTH 32 // the length of the binary random string before base64url encoding + +struct code_verifier_s { + size_t length; + char* key; +}; + + + +static inline void free_code_verifier(struct code_verifier_s* ptr) { free(ptr->key); } + +static struct code_verifier_s generate_code_verifier() { + + uint8_t code_buffer[CODE_VERIFIER_LENGTH]; + + int gen_result = RAND_priv_bytes(code_buffer, sizeof(code_buffer)); + if (gen_result != 1) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to generate code_verifier.\n"); + + return (struct code_verifier_s) { 0, NULL }; + } + + size_t encoded_size = 0; + char* encoded = base64url_encode(code_buffer, sizeof(code_buffer), &encoded_size); + if (!encoded) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to encode code_verifier.\n"); + return (struct code_verifier_s) { 0, NULL }; + } + + // trim trailing =\n + encoded_size -= 2; + encoded[encoded_size] = '\0'; + + return (struct code_verifier_s) { encoded_size, encoded }; +} + +static uint8_t* sha256_encode(const char* str, size_t str_len) { + + EVP_MD_CTX* hash_ctx = EVP_MD_CTX_new(); + if (!hash_ctx) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to create Crypto context.\n"); + + return NULL; + } + + EVP_MD* sha256_ctx = ssl_get_sha256(); + if (!sha256_ctx) { + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + if (!EVP_DigestInit_ex(hash_ctx, sha256_ctx, NULL)) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to initialize sha256 digest.\n"); + + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + if (!EVP_DigestUpdate(hash_ctx, str, str_len)) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to set message to digest.\n"); + + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + uint8_t *hash_buf = malloc(EVP_MD_get_size(sha256_ctx)); + if (!hash_buf) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to allocate memory for digest output.\n"); + + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + unsigned int len = 0; + if (!EVP_DigestFinal_ex(hash_ctx, hash_buf, &len)) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to compute hash.\n"); + + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + EVP_MD_CTX_free(hash_ctx); + + return hash_buf; +} + +char* _construct_get(char* server_url, char* client_id, char* redirect_uri, char* scope, char* state) { + + static char msg_template[] = "GET %s?response_type=code&client_id=%s&state=%s&redirect_uri=%s HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t len = strlen(msg_template) - 4*2 + strlen(server_url)+strlen(client_id)+strlen(redirect_uri)+strlen(scope)+strlen(state) + 1; + + char* buf = malloc(len); + if (!buf) { return NULL; } + + if (snprintf(buf, len, msg_template, server_url, client_id, state, redirect_uri) < 0) { + free(buf); + return NULL; + } + + return buf; +} + +void do_pkce() { + + log_log("PKCE", LOG_INFO, "generating code verifier"); + + struct code_verifier_s code_verifier = generate_code_verifier(); + + log_log("PKCE", LOG_DEBUG, "generated verifer: %s\n", code_verifier.key); + log_log("PKCE", LOG_INFO, "generating code challenge"); + + uint8_t* code_challenge_sha256 = sha256_encode(code_verifier.key, code_verifier.length); + if (!code_challenge_sha256) { + log_log("PKCE", LOG_ERR, "Failed to encode code challenge"); + goto pkce_cleanup; + } + + size_t code_challenge_length = 0; + char* code_challenge = base64url_encode(code_challenge_sha256, EVP_MD_get_size(ssl_get_sha256()), &code_challenge_length); + + code_challenge_length -= 2; + code_challenge[code_challenge_length] = '\0'; + + log_log("PKCE", LOG_DEBUG, "generated code challenge: %s\n", code_challenge); + +pkce_cleanup: + free_code_verifier(&code_verifier); + free(code_challenge); +} diff --git a/pkce.h b/pkce.h new file mode 100644 index 0000000..037d9b1 --- /dev/null +++ b/pkce.h @@ -0,0 +1,6 @@ +#ifndef _OAUTH2LIB_PKCE_ +#define _OAUTH2LIB_PKCE_ + +void do_pkce(); + +#endif //_OAUTH2LIB_PKCE_ diff --git a/server.c b/server.c new file mode 100644 index 0000000..9b1ce92 --- /dev/null +++ b/server.c @@ -0,0 +1,141 @@ +#include "server.h" + +#include +#include +#include +#include + +#include + +#include "log.h" +#include "ssl.h" + +#include "util/thread_queue.h" + +#define MODULE_NAME "login_server" + +#define SERVER_PORT 443 + +static const char message[] = "

Hello World!

"; + +static SSL_CTX* _ssl_server_ctx = NULL; +static struct logger_s _logger; + +static pthread_t *_server_thread = NULL; +static struct threadqueue _server_thread_queue; + +_Atomic enum server_state _server_state = SERVER_NOT_STARTED; + +static enum server_start_return _server_thread_func(void* data) { + + atomic_store(&_server_state, SERVER_STARTING); + + struct sockaddr_in sockaddr = { + .sin_family = AF_INET, + .sin_port = htons(SERVER_PORT), + .sin_addr = { + .s_addr = htonl(INADDR_ANY) + } + }; + + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + log_log(MODULE_NAME, LOG_ERR, "Failed to create socket: %s", strerror(errno)); + atomic_store(&_server_state, SERVER_FAILED); + return SERVER_FAIL; + } + + if (bind(sock, (struct sockaddr*)&sockaddr, sizeof(sockaddr)) < 0) { + log_log(MODULE_NAME, LOG_ERR, "Failed to bind socket: %s", strerror(errno)); + atomic_store(&_server_state, SERVER_FAILED); + close(sock); + return SERVER_FAIL; + } + + if (listen(sock, 1) != 0) { + log_log(MODULE_NAME, LOG_ERR, "Failed to start listening on socket: %s", strerror(errno)); + atomic_store(&_server_state, SERVER_FAILED); + close(sock); + return SERVER_FAIL; + } + + log_log(MODULE_NAME, LOG_INFO, "Waiting for incoming connections..."); + atomic_store(&_server_state, SERVER_RUNNING); + + while(1) { + struct sockaddr addr; + socklen_t addr_len = sizeof(addr); + + int client_fd = accept(sock, &addr, &addr_len); + if (client_fd < 0) { + log_log(MODULE_NAME, LOG_ERR, "Failed to accept an incoming connection: %s", strerror(errno)); + + continue; + } + + SSL* ssl_conn = SSL_new(_ssl_server_ctx); + SSL_set_fd(ssl_conn, client_fd); + + if (SSL_accept(ssl_conn) <= 0) { + log_log_ssl_err(MODULE_NAME, LOG_ERR, "SSL Handshake failed."); + } else { + + size_t bufsize = 1024; + char* buf = malloc(bufsize); + + if (SSL_read(ssl_conn, buf, bufsize) > 0) { + log_log(MODULE_NAME, LOG_INFO, "Received message: %s", buf); + } + + SSL_write(ssl_conn, message, sizeof(message)); + } + + SSL_shutdown(ssl_conn); + SSL_free(ssl_conn); + close(client_fd); + + } + +_server_thread_shutdown: + close(sock); +} + +enum server_start_return server_start() { + initialize_ssl(); + + if (!_ssl_server_ctx) { + _ssl_server_ctx = SSL_CTX_new(TLS_server_method()); + + if (!_ssl_server_ctx) { + log_log(MODULE_NAME, LOG_ERR, "Failed to create SSL CTX"); + return SERVER_FAIL; + } + + if (!ssl_use_ppcerts(_ssl_server_ctx)) { + log_log(MODULE_NAME, LOG_ERR, "Failed to set certificates."); + SSL_CTX_free(_ssl_server_ctx); + _ssl_server_ctx = NULL; + return SERVER_FAIL; + } + } + + if (!_server_thread) { + _server_thread = malloc(sizeof(pthread_t)); + + if (!_server_thread) { + log_log(MODULE_NAME, LOG_ERR, "Failed to allocate resources for server_thread"); + return SERVER_FAIL; + } + + thread_queue_init(&_server_thread_queue); + pthread_create(_server_thread, NULL, (void* (*)(void*))_server_thread_func, NULL); + + pthread_join(*_server_thread, NULL); + + free(_server_thread); + _server_thread = NULL; + } + + + return SERVER_OK; +} diff --git a/server.h b/server.h new file mode 100644 index 0000000..3590364 --- /dev/null +++ b/server.h @@ -0,0 +1,20 @@ +#ifndef _OAUTH2LIB_SERVER_ +#define _OAUTH2LIB_SERVER_ + +enum server_start_return { + SERVER_FAIL = 0, + SERVER_OK = 1, + SERVER_STARTED +}; + +enum server_state { + SERVER_FAILED, + SERVER_NOT_STARTED, + SERVER_STARTING, + SERVER_RUNNING, + SERVER_FINISHING, +}; + +enum server_start_return server_start(); + +#endif //_OAUTH2LIB_SERVER_ diff --git a/ssl.c b/ssl.c new file mode 100644 index 0000000..c3c93df --- /dev/null +++ b/ssl.c @@ -0,0 +1,183 @@ +#include "ssl.h" +#include "log.h" + +#include +#include + +#include +#include +#include +#include +#include + +#define LOG_NAME "SSL" + +static bool _initialized = false; +static SSL_CTX* _ssl_ctx = NULL; +static EVP_MD* _sha256_ctx = NULL; + +static const char _pkey_file_name[] = "key.pem"; +static const char _cert_file_name[] = "cert.pem"; +static const char _passphrase[] = "This is absolutely insecure."; + +static int _pem_passphrase_cb(char* buf, int size, int rwflag, void* userdata) { + int len = strlen(_passphrase); + memcpy(buf, _passphrase, len < size ? len : size); + return len; +} + +bool _initialize_ssl() { + + SSL_library_init(); + SSL_load_error_strings(); + + _ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (_ssl_ctx == NULL) { + fprintf(stderr, "Failed to initialize Openssl.\n"); + + return false; + } + + SSL_CTX_set_verify(_ssl_ctx, SSL_VERIFY_PEER, NULL); + + if (!SSL_CTX_set_default_verify_paths(_ssl_ctx)) { + fprintf(stderr, "Failed to set default paths for openssl.\n"); + + return false; + } + + if (!SSL_CTX_set_min_proto_version(_ssl_ctx, TLS1_3_VERSION)) { + fprintf(stderr, "Failed to set minimum TLS version.\n"); + + return false; + } + + return true; +} + +bool initialize_ssl() { + if (!_initialized) { + _initialized = _initialize_ssl(); + } + return _initialized; +} + +void ssl_close_ssl() { + + SSL_CTX_free(_ssl_ctx); + EVP_MD_free(_sha256_ctx); + +} + +SSL_CTX* ssl_get_ctx() { return _ssl_ctx; } +EVP_MD* ssl_get_sha256() { return _sha256_ctx; } + + +bool ssl_generate_cert() { + + EVP_PKEY* pkey = NULL; + X509* cert = NULL; + X509_NAME* cert_issuer; + bool success = false; + + pkey = EVP_EC_gen("P-256"); + + if (pkey == NULL) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to generate private key for certificate."); + goto generate_cert_cleanup; + } + + cert = X509_new(); + if (cert == NULL) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to generate certificate."); + goto generate_cert_cleanup; + } + + ASN1_INTEGER_set(X509_get_serialNumber(cert), 1); + X509_gmtime_adj(X509_get_notBefore(cert), 0); + X509_gmtime_adj(X509_get_notAfter(cert), 31536000L); + + cert_issuer = X509_get_subject_name(cert); + + if (cert_issuer == NULL || + !( X509_NAME_add_entry_by_txt(cert_issuer, "C", MBSTRING_ASC, (unsigned char*)"AT", -1, -1, 0) + && X509_NAME_add_entry_by_txt(cert_issuer, "O", MBSTRING_ASC, (unsigned char*)"me", -1, -1, 0) + && X509_NAME_add_entry_by_txt(cert_issuer, "CN", MBSTRING_ASC, (unsigned char*)"localhost", -1, -1, 0))) { + + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to create issuer name."); + goto generate_cert_cleanup; + + } + + if (!X509_set_issuer_name(cert, cert_issuer)) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to set issuer name on certificate."); + goto generate_cert_cleanup; + } + + if (!X509_set_pubkey(cert, pkey)) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to set private key on certificate"); + goto generate_cert_cleanup; + } + + if (!X509_sign(cert, pkey, EVP_sha256())) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to sign certificate."); + goto generate_cert_cleanup; + } + + FILE* key_file = NULL; + key_file = fopen(_pkey_file_name, "wb"); + + if (!PEM_write_PrivateKey(key_file, pkey, EVP_des_ede3_cbc(), (unsigned char*)_passphrase, sizeof(_passphrase), NULL, NULL)) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to write private certificate to disk."); + fclose(key_file); + goto generate_cert_cleanup; + } + fclose(key_file); + + key_file = fopen(_cert_file_name, "wb"); + if (!PEM_write_X509(key_file, cert)) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to write certificate to disk."); + fclose(key_file); + goto generate_cert_cleanup; + } + fclose(key_file); + + success = true; + +generate_cert_cleanup: + + EVP_PKEY_free(pkey); + X509_free(cert); + X509_NAME_free(cert_issuer); + + return success; +} + +bool ssl_use_ppcerts(SSL_CTX* ctx) { + + if (access(_cert_file_name, R_OK) != 0 && access(_pkey_file_name, R_OK) != 0) { + log_log(LOG_NAME, LOG_INFO, "Could not access certificates - Generating new ones"); + bool res = ssl_generate_cert(); + if (!res) { + log_log(LOG_NAME, LOG_ERR, "Failed to generate certificates"); + return false; + } + } + + if (SSL_CTX_use_certificate_chain_file(ctx, _cert_file_name) != 1) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to load SSL certificate."); + return false; + } + + SSL_CTX_set_default_passwd_cb(ctx, _pem_passphrase_cb); + + if (SSL_CTX_use_PrivateKey_file(ctx, _pkey_file_name, SSL_FILETYPE_PEM) != 1) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to load SSL Private Key."); + return false; + } + + SSL_CTX_set_default_passwd_cb(ctx, NULL); + + return true; +} + diff --git a/ssl.h b/ssl.h new file mode 100644 index 0000000..535c762 --- /dev/null +++ b/ssl.h @@ -0,0 +1,17 @@ +#ifndef _OAUTH2LIB_SSL_ +#define _OAUTH2LIB_SSL_ + +#include + +#include + +bool initialize_ssl(); + +SSL_CTX* ssl_get_ctx(); +EVP_MD* ssl_get_sha256(); + +bool ssl_generate_cert(); + +bool ssl_use_ppcerts(SSL_CTX* ctx); + +#endif //_OAUTH2LIB_SSL_ diff --git a/util/a.out b/util/a.out new file mode 100755 index 0000000000000000000000000000000000000000..04e9f5bed92649af73d79a6aa8294cc7b4e9d8cc GIT binary patch literal 20968 zcmeHPeRx#WnLl?XC%MUx%p@TgJ}d(g#9|UapnxC~VDN&VsNjaG9fo9*OieQBWQGqb ziV14dNK5F}wp;cQ+uhb$+q(8+U8OZBOM=_Fv9+4r$3B}Xc9Wu(Qf+CgW`FPbxO3-{ z^lAG%yZ>xXp4|7mpXWXAdG9&re$8FsP3t{|!F2Jms~BMe0Z9od#;~O-03o)F72-ID zUBYsK=LnkQha>=3mO)Q8G!@+oNP6W|83WFcWTv6^kRa)eR}QHLG-RqDq{m4~;cmza zr9}-j-)=7``OF0msQzi_)T7u8R7$;pRoU>6()+xET8~iD8&Y~hO3#Mm5b+h%@uZm0 z|1uR%Hw=MCB_xiFx%KuaJvUsZf(zs90pF%G?K9Qw|2 z=xyWBUj#jfe|L$4FqS_n#-Yz2hyEJqLHxVRjUdc+6A5omCF|PuvScjU+?=S3CX>;< zED~u?MeA;h)HU4}X^6&~Swk`wV=b|kmc*_YX*9=LS&LABEHbf<6q2b(OElgJO5L7l zq#@oKZI0g_V{OTJYpQ{?sU2#rrHzpeCF@dl^`;GL)sokZaV1B`5@q+i}gnu(lT~pHg(O0DCdsNeS|N!LT`j$mFmarSz(fQlA}|qwi3m(Y;D0>=zwv+I z4_4n(IQnOXP$RsFcq2ojC>75TPFs_a z0f9eBIBk(edIkPC;j}dxc}(Ds5Kdc`k)s0t65+H}895~IzayNsC?op?ewc9DCXBQR z{F8*!mSkjyz&i-1t;on0;MAXmZ*J&%;UiYpt5)afvs<=?7xpeZWgT3#6;6(p{aY1d zZ#Pu=XC8nGI(%@|b{Mb@@*4;&Kc6Z?D830H3Pt|tkbmZWl6zJ`$X_Ppm&_(`+51-4 zS?l;)tF7bby_Rv(dg1-lWEj||449)s4gQ(yGWz7-{#CDo&C*wFvpQFOn5tIS>#0Jk zd(}xGqh%J3!*%G)Nq!7fV9k;e>KN420$>o12$_jSf{MxKL@Md>V_*P5sWvnC$qZx=&Fwm*ScXaJEq^BY;GsjOo>eV}>aFgaig{45y29rd_C8HxJWL}C9G3Bauv79sLH8BD`qU-(5z zMF_m81Zpw@TZO=&G`?+63Czg|TpFF8i-A?JR zWTdamNWY1WDcc*I(t9$}GcwXNeoA_}Q+i`YdIURtt4sLZlaYQ?P*pEyK)Nq{8lk{| zZab|a_M{LI5py*2gj4agKf0s$Pfqp;u=~QpbZG^y$qI2YbBV5_DmEen$E}r}XzS($8n4-$R}%PqsUyKcS`V#C*`nj%zk@ z_$wKw&tZ;LgI$W#Kz01p4D~&o`v=QTW~l$2p+23VzL24QEkk`uQ*Fm`osJFsUd05C z6>7&&@4_iBW-6T3ydr$A7t+X3!VT6VeSDkS>}Jg?c$07`A6wnqt8f z16Ki0y27VB_nl_xDOTscewNAyyT5ze=`NBRJ`*C|-Xf}NC(cZgSTDIm33}GLGJM|u zsoRjUBz-5hI>Yd{Pb{=q9qoTQ2+xO~KaYrY!Hdp)11y~80?9g9!%CWaRRV=q4sYr7=2Q`EbTfNHCFd#L4Q(oP0r%a z(bSAAKxod7YRQZ2#*k(9Acsu5%RA8JC(z7eNU(>ISp6zkPl&|2r`LbrIf@w8uq=km zxWT^g2_laiBpFJl`7|CEWZZzvxPFL`UCOwA+ZoCvv@SC4+rOoZJJGrC1WV5tQ#xdh zSePM5LgZb)%DX|4cfds64M2AQc{k|)R7~XE37L2OhBUgxRK^3?5cuyoDcvXE0toXS zUxWWwz2_|d2Yb&8R?LFrY6wb$-a?Zc~c!R&1Uy6FW4_cS^ zVi<)sJ#ij!;KFiRH* z-4P2jhg86MD^nPrV`1;wqGH)Myn+-(Ec%A;lt$=Wgd9xw4c{h7=RrCvwv4Eprv#fV zD3Cbywn%VH8GXawM_o4kBN-@|+NOu>F)0Yk)Hp0|eposGTbqMZVtKefWkCRPz~*D& zJ3f>n=OZGOLk}d{&mC7%Af+*$Ck^ZPFF-!4toCNCLPF)y!AaCOxdn$Y@m#OR*R&39 zkz?Lvk9kJTaGm<{bi44>O3>m$vsvf9v!kh9XlhucF}BfidrBvs=)Q<=_it46#G24I z{Bens7I{QlC>lEQ5;OtNT9@Ixfibdj(2~KFXZDYvea67lU|;vxS2`Gvi82v^|7Q`P zXERi0SG3Pw5v-^$V7JGT30Bcoz$yt#HPNFfJo;gkbZAJVlPtA6!3x^(>kS_ z1r_zd_NGL-xjwiv7F>jFn=H0ia9)O;YKjHhqb;#uv>nn7iDc|Dw1j6~vF6w=fYw-J zG!@?!3yS6%;teT0*vho0Z8LUWJvA9xYfSfENzaU_;o4te^#?T1j*jjDG=427-H!tf z0-uhh@jPG-b_wNJ0ceA<46qDapIX2>0QUeMz#jASfK=#NEfwST8tU`(HsXC4BN z=ri#D3A8zv1lHxzqJwuM7+Z}0-nT|aA0t9hV0}^fM!#>jxu30`cI5}Jm@}Kmq)UF+ zz}MS}P*W7R*R!@T*ONvIl79!tUj+Xst~YGH_lr3sK>A1UKOM`^2%ydPoVN9kfqWYL zd$RcZZT>5uZ$M1$&*I-_^WO#iyWlrw@psz%8HmHX;G1-9uKo9Vy(BOLq?O>G#s>J& zEd8$veIiYii3m(YU?Kt&5txX;LK^J9cCdPjX}t zS1Flyug=p(9c8Kn_Iq>F6)x`lqCQRG`u#bjX1^o$&ikVY66vLD0TkTP7iEX4;3l3Z zLyE626yH;HJQ)$CTtRv-)LE1o8&+3-^6Hhbvr0%`r_lprDi=w>jw*tt4=G1Pbhq4@5IX%>#gix5f4{ zPoAi=oDkTi`B~6S=ItUCBexzPkLpFHKj`u0zkq7aTpT9lFTs)J;1vK?DC#EN0ANL; z9>`gPdf)>jUe3I?0_4)m)t+MS=K%6v!okSNyAv%MhIun#p1e2#BbS5?4<7;W9>T%A z8i&aSC6E)t0JV!NYXIp3V(WPr~Fi9)kndlk{@UJK5EF*(~RRIh&Cn3 zHLx5INkZ92KDg8CUPyRkl9-Mp@yH}8mPrBuN)nGuk`iZ$>9 z%`%-Nd5$E>b0mqF;#894Ig^CbxSNh7DWq}s$n5aQ?C{8xL6aoxkx61chz{mFGo2bS zA3~ism(uW{gDJwQaG1g=lk(|_N_7$C$Bn2NR5322S`BI}|7)nu7wqI%>==j#i|D+$ zpV(9}s!?MF-8klc&hS&Re3^KXyam;~?|X#B6GRpb){q_@*hM$tkoPT5;mexI3LZl> z@2J#yi^xJUpK|etNI%n~?8`=q=tTJXCbiN%N@h`vBey~RQ7}bdJG^-NpFJ*4S5Hq9;}m)yolPy`L|La6UU3gim|uk@Lzr)=B=g>$ zhwl>hTdpGe-kzg(O`=R`xsJl??fLv&qP>Ke9Oy*I{Ta-B_F^%NJv+)jrW$w-R4PR< z51)3kq)&Pn@}}F`@YqdamNyw|XQ@0o9lSl?xr_O?s-^#=P{(M{x!zqs1A@+aQ~2uO z{+TozArR;q%>f%l*M6qf(nh6d9HCbvMAtShpsJ4C#u^nlZ=7vWaYLkDg=Uwb?l1C(&6|IiI>kQS{OP zSC>fRq1-hfFH(bOv8ktk8vX!}a^Hs1W!PiklwCd5rd>L9L+Q$?1|)N*kfcuv;xw5nESXN}S-F=^mBRUp znb%w~%aGC@#J!-{s3d9 zN$@SZ{X#j8MNTyuN2iuw_djsqAQoo^vBW)yrS3tTEC;bHdl0A4AYPQ+9hpYuvSIRk zY85g>CWO0FtOKV9$G)kCM#wP!SX4*=MEWiUm6={znm>epnrJ-6Opo8OTTFo_J}NRk zNnAdk<$LrFVJi1M#XTv$w3@H%;H&$&cO$>JieJ*fz5RR^^W=EVPxHxu#zsDsK!e0O zxN$R|w5fxa|AKo4iIHp8=9*1>u6cm$%`tljRAd23G!Gx+<>t#AA426AW-aRJzo>4R zyZGFcQaJ?WIp!d>b*XHtN`ko(l$(eD665ABbMTnCUHyF+2+u#pjqRQ(2s3@RN~2JV zD2_P_T#v$GNc;o-?a?sBaEL=@G`>rtK?m>mamb;?C5K;!!O`=nx5(fH=BiF!+R3MN z@Pdsz_Zp7gmu)9B{~Ry>CjUSOuXvQ_ujIMrES}fMr%_Z>+}Fc%}F02F~$<6rZ|@&+OoHALRurc^-P~Z{%ei ze8!_Z?=!p@{p;YxO>`{srV%Z#|0SLS>MeXq2id-;htGx|o+TxGi6>ObD>@LiH21-s zhj{T*Ki~xvkSF=Y|H6wl@rz&OzH2$($jgaZ%L4>9@=4oql5|p_JV(uH_2zS-?vW9`hS~-X31nK?7QF zJMy`c^S4QM+A%(>gYvs%BO=PZ-lSQApap^+d{1c9McblvaXcdypF&1x#cod|Q?Yt{ zqL?H!(i+=M&z4(at*Hn-NOtjAeKZwy(IVoT%yeseystxL2Q-?jYg#2gUJh0^*6)rc8-kVBEs$>1C(`sir;3dI0czH_st8>K z0oLEKCeo=0?F=*B$PS16c2#|a8gH$S?V+z-Ddv%s{zg@O*_w*AV>Hx2#Zpyuthy=D z602^Drr`T+)pd#bSal+rPBlg9n&TL{>U0X<%2ub_lk|0Vbz@y!HF_3Vy125fPJO1` z*qW}E_Zh0=t#!@m`k0j0h@-hCbg8~R)=(uX@dThStwgJL##8Otl=i*Y5$|n_wF?z| zcH1a*;S4H&y3jh6!+t&LmdYK6&Y|F8MHg%2hl<`UycVV(#mv{{%e%`#VCYILgmaC;#1-alcb^_eF}h`{402eP2|Y3c&xd{JBKX zi!QjRSUQgU_2ba%K_~y+@%h*|@&`ZTm{re%b>!b1?O5Qy#uPb@?y`5hQy3}W2(m{OsQyh}5?yB2CfOdaPmc=MQXg{Mtce!*wuKA8(DM z+hg_W=M-!K(O#^z9*yp7Z&z)KWn29WL$+jO>*lqQ@U`nAf9mHBvVWK%+W=yA!Dd11 zFj!Ul-j-B!CtxZmVUw<5rH>`sSXFBx6|2J1R@IhFv|$sm*FoBuj^js4;`NGDvt~nO zD%!||cvG~!iB;9_ZAB9jrjnAgE0%1JCt9VC6R?Ftm1sa3!!WR_xPbmz`wY?sFHBP}sqNAz@+$g~N}I>rAggr0+8>QPT6F8tjh zb$8IVO|s5%2@txsfayh)Ww>!gMZ=K5G6&o)$tp{MIraA_eGTViNz)!7yKHdjA5r=m z2D8O4(El6o@rqwU`>&txX{h^8`xz>3|D&MMvpub^pZ{sNUj>Z%NL8&*w|A(-@lP&m z{T=GsUc>iD$Wh$!{}_1mtWN9e=W`kcm9F+*%V~Hg`1I^gVC~Z?Ai5{iFPL7I*#}0n@1; zRC*dd=2S*qv2FNOm%jdeF%9QoDRvgEuW91>8W8Qj-j``us09^t`~Pjw$$uSx7F2=? z_A68a9Y2lJ?QExhe^7EXyhRHt=x+ZM=uZ84v!tMgw-D1&-0lAaJlxt9=KzW6-?h=d zYvq>LxQ71)9lt|gKffQWBx{*`kyr-T`e(qP*sE#7zRs|Jr%d@z`A2Z7;G!oJJ z`uB8t=|Ti$o&@$$>udOXaGd)3x@B1DH`$u9qV+WNK-sCU_t#q%NO|q2Hly`y9Lhw~ z{_hA$&JLyTmZ!FyrBJo6#BG+=r}^Dkw7!P)h(Xt#at|njJDfz+-Le(q=>Jyf|3sfy z%8pxq;W+xIJ|q=>Zl72Zck8bnNB@x(Qhu82gmLTBLs+7@<9FgJssGO|1&wNW6Dm}9 z>wi2X_5WQ0`>6Gt(A7mDyQM~|464MK4U}?0=Yw9q@ +#include +#include + +#define SORTED_STR_SET_INITIAL_CAPACITY 8 +#define SORTED_STR_SET_GROWTH_FACTOR 1.5 + +static ssize_t _get(const sorted_str_set* set, const char* key, int* out_res) { + ssize_t index = 0; + ssize_t low = 0; + ssize_t upp = set->size - 1; + int res = 0; + + while (low <= upp) { + index = low + (upp-low)/2; + + res = strcmp(key, set->data[index].key); + if (res == 0) { + if (out_res) *out_res = res; + return index; + } + + if (res < 0) { + upp = index - 1; + } else { + low = index + 1; + } + } + + if (res > 0) { + ++index; + } + + return -1; +} + + +sorted_str_set* sorted_str_set_new(size_t element_size) { + + sorted_str_set* set = malloc(sizeof(sorted_str_set)); + if (!set) { return NULL; } + + struct sorted_str_set_pair* data = malloc(SORTED_STR_SET_INITIAL_CAPACITY * sizeof(struct sorted_str_set_pair)); + if (!data) { + free(set); + return NULL; + } + + *set = (sorted_str_set) { + .element_size = element_size, + .size = 0, + .capacity = SORTED_STR_SET_INITIAL_CAPACITY, + .data = data, + }; + return set; +} + +void sorted_str_set_free(sorted_str_set* set) { + assert(set); //, "set must not be NULL."); + for (int i = 0; i < set->size; ++i) { + free(set->data[i].key); + free(set->data[i].element); + } + free(set->data); + free(set); +} + +void* sorted_str_set_get(const sorted_str_set* set, const char* key) { + + assert(set); //, "set must not be null"); + assert(key); //, "key must not be null"); + + if (set->size == 0) { + return NULL; + } + + ssize_t i = _get(set, key, NULL); + + if (i == -1) { + return NULL; + } else { + return set->data[i].element; + } +} + +void* sorted_str_set_insert(sorted_str_set* set, const char* key, const void* element) { + + assert(set); //, "set must not be NULL"); + assert(key); //, "key must not be NULL"); + assert(element); //, "element must not be NULL"); + + size_t index = 0; + + if (set->size != 0) { + ssize_t low = 0; + ssize_t upp = set->size - 1; + int res = 0; + + while (low <= upp) { + index = low + (upp-low)/2; + + res = strcmp(key, set->data[index].key); + if (res == 0) { + return set->data[index].element; + } + + if (res < 0) { + upp = index - 1; + } else { + low = index + 1; + } + } + + if (res > 0) { + ++index; + } + } + + if (set->size == set->capacity) { + size_t new_cap = set->capacity * SORTED_STR_SET_GROWTH_FACTOR; + struct sorted_str_set_pair* tmp = reallocarray(set->data, new_cap, sizeof(struct sorted_str_set_pair)); + if (!tmp) { + return NULL; + } + + set->data = tmp; + set->capacity = new_cap; + } + + struct sorted_str_set_pair pair = { + .key = malloc(strlen(key) + 1), + .element = malloc(set->element_size), + }; + if (!pair.key || !pair.element) { + free(pair.key); + free(pair.element); + return NULL; + } + strcpy(pair.key, key); + memcpy(pair.element, element, set->element_size); + + if (set->size - index) { + memmove(&set->data[index + 1], &set->data[index], (set->size - index) * sizeof(struct sorted_str_set_pair)); + } + + set->data[index] = pair; + set->size++; + + return set->data[index].element; +} + +void sorted_str_set_remove(sorted_str_set* set, const char* key) { + + assert(set);//, "set must not be NULL"); + assert(key);//, "key must not be NULL"); + + ssize_t index = _get(set, key, NULL); + if (index == -1) { + return; + } + + struct sorted_str_set_pair* loc = &set->data[index]; + + free(loc->key); + free(loc->element); + + set->size--; + + memmove(loc, loc+1, (set->size - index) * sizeof(struct sorted_str_set_pair)); + +} + diff --git a/util/sorted_str_set.h b/util/sorted_str_set.h new file mode 100644 index 0000000..23184f4 --- /dev/null +++ b/util/sorted_str_set.h @@ -0,0 +1,69 @@ +#ifndef _OAUTH2LIB_UTIL_SORTED_SET_ +#define _OAUTH2LIB_UTIL_SORTED_SET_ + +#include + +struct sorted_str_set_pair { + char* key; + void* element; +}; + +typedef struct { + size_t element_size; + size_t size; + size_t capacity; + + struct sorted_str_set_pair* data; +} sorted_str_set; + + +/** + * @name sorted_str_set_new + * @description Allocates and returns a new instance of sorted_str_set. Must be freed with @sorted_str_set_free + * + * @param element_size the size of each element (not the key) + * @return the newly allocated set or NULL if allocation fails + */ +sorted_str_set* sorted_str_set_new(size_t element_size); + +/** + * @name sorted_str_set_free + * @description Frees all resources of an instance of sorted_str_set. Has to be called before the end of the lifetime of the variable. + * Does not free any memory possibly pointed to by element. + * + * @param set the set to be deallocated. Must not be NULL. + */ +void sorted_str_set_free(sorted_str_set* set); + +/** + * @name sorted_str_set_free + * @description Returns a pointer to the element associated with key within set + * + * @param set the set to be accessed. Must not be NULL. + * @param key the key whose value should be retrieved. Must be null-terminated. + * @return a pointer to the element associated with key or NULL if key is not in the set + */ +void* sorted_str_set_get(const sorted_str_set* set, const char* key); + +/** + * @name sorted_str_set_insert + * @description Inserts a new element associated with key into set and returns a pointer to the element. + * If key is already present in set, nothing is modified and a pointer to the existing element is returned. + * + * @param set the set to be accessed. Must not be NULL. + * @param key the key for the element. Must not be NULL. Must be null-terminated. + * @param element the element to be inserted. A copy of element is made. Is expected to be the size of set->element_size. Must not be NULL. + * @return a pointer to the newly inserted element or if the key is already existing, the element associated with key. + */ +void* sorted_str_set_insert(sorted_str_set* set, const char* key, const void* element); + +/** + * @name sorted_str_set_remove + * @description Removes the key and element associated by key from set. If the key is not present, does nothing. + * + * @param set the set to be accessed. Must not be NULL. + * @aram key the key whose element should be removed. Must not be NULL. + */ +void sorted_str_set_remove(sorted_str_set* set, const char* key); + +#endif //_OAUTH2LIB_UTIL_SORTED_SET_ diff --git a/util/test_sorted_str_set.c b/util/test_sorted_str_set.c new file mode 100644 index 0000000..7cc973c --- /dev/null +++ b/util/test_sorted_str_set.c @@ -0,0 +1,67 @@ +#include "sorted_set.h" + +#include +#include + +#define REF(dtype, value) ((dtype []){ (value) }) + +void print_set(const sorted_str_set* set) { + for (int i = 0; i < set->size; ++i) { + printf("\"%s\": %d\n", set->data[i].key, *(int*)set->data[i].element); + } +} + +int main() { + + sorted_str_set *set = sorted_str_set_new(sizeof(int)); + + printf("%p\n", sorted_str_set_get(set, "zero")); + + printf("-%d\n", strcmp("-one", "one")); + + sorted_str_set_insert(set, "three", REF(int, 3)); + sorted_str_set_insert(set, "one", REF(int, 1)); + sorted_str_set_insert(set, "-one", REF(int, -1)); + sorted_str_set_insert(set, "four", REF(int, 4)); + sorted_str_set_insert(set, "two", REF(int, 2)); + + print_set(set); + + printf("\nsize: %lu\n", set->size); + printf("%d\n", *(int*)sorted_str_set_get(set, "one")); + printf("%d\n", *(int*)sorted_str_set_get(set, "two")); + printf("%d\n", *(int*)sorted_str_set_get(set, "three")); + printf("%d\n", *(int*)sorted_str_set_get(set, "four")); + printf("%d\n", *(int*)sorted_str_set_get(set, "-one")); + + printf("\n%d should be 2\n", *(int*)sorted_str_set_insert(set, "two", REF(int, -2))); + printf("size: %lu (should be the same as before)\n", set->size); + + sorted_str_set_insert(set, "ten", REF(int, 10)); + sorted_str_set_insert(set, "eleven", REF(int, 11)); + sorted_str_set_insert(set, "negative four", REF(int, -4)); + sorted_str_set_insert(set, "fifty", REF(int, 50)); + + printf("%d\n", *(int*)sorted_str_set_get(set, "ten")); + printf("%d\n", *(int*)sorted_str_set_get(set, "eleven")); + printf("%d\n", *(int*)sorted_str_set_get(set, "negative four")); + printf("%d\n", *(int*)sorted_str_set_get(set, "fifty")); + + printf("size: %lu\n", set->size); + + print_set(set); + + sorted_str_set_remove(set, "three"); + sorted_str_set_remove(set, "one"); + sorted_str_set_remove(set, "three"); + sorted_str_set_remove(set, "four"); + sorted_str_set_remove(set, "four"); + + printf("size: %lu\n", set->size); + + print_set(set); + + sorted_str_set_free(set); + + return 0; +} diff --git a/util/thread_queue.c b/util/thread_queue.c new file mode 100644 index 0000000..bf03db3 --- /dev/null +++ b/util/thread_queue.c @@ -0,0 +1,204 @@ +#include "thread_queue.h" + +#include +#include +#include +#include +#include + + +#define MSGPOOL_SIZE 256 + +struct msglist { + struct threadmsg msg; + struct msglist *next; +}; + +static inline struct msglist *get_msglist(struct threadqueue *queue) +{ + struct msglist *tmp; + + if(queue->msgpool != NULL) { + tmp = queue->msgpool; + queue->msgpool = tmp->next; + queue->msgpool_length--; + } else { + tmp = malloc(sizeof *tmp); + } + + return tmp; +} + +static inline void release_msglist(struct threadqueue *queue,struct msglist *node) +{ + + if(queue->msgpool_length > ( queue->length/8 + MSGPOOL_SIZE)) { + free(node); + } else { + node->msg.data = NULL; + node->msg.msgtype = 0; + node->next = queue->msgpool; + queue->msgpool = node; + queue->msgpool_length++; + } + if(queue->msgpool_length > (queue->length/4 + MSGPOOL_SIZE*10)) { + struct msglist *tmp = queue->msgpool; + queue->msgpool = tmp->next; + free(tmp); + queue->msgpool_length--; + } +} + +int thread_queue_init(struct threadqueue *queue) +{ + int ret = 0; + if (queue == NULL) { + return EINVAL; + } + memset(queue, 0, sizeof(struct threadqueue)); + ret = pthread_cond_init(&queue->cond, NULL); + if (ret != 0) { + return ret; + } + + ret = pthread_mutex_init(&queue->mutex, NULL); + if (ret != 0) { + pthread_cond_destroy(&queue->cond); + return ret; + } + + return 0; + +} + +int thread_queue_add(struct threadqueue *queue, void *data, long msgtype) +{ + struct msglist *newmsg; + pthread_mutex_lock(&queue->mutex); + newmsg = get_msglist(queue); + if (newmsg == NULL) { + pthread_mutex_unlock(&queue->mutex); + return ENOMEM; + } + newmsg->msg.data = data; + newmsg->msg.msgtype = msgtype; + + newmsg->next = NULL; + if (queue->last == NULL) { + queue->last = newmsg; + queue->first = newmsg; + } else { + queue->last->next = newmsg; + queue->last = newmsg; + } + + if(queue->length == 0) + pthread_cond_broadcast(&queue->cond); + queue->length++; + pthread_mutex_unlock(&queue->mutex); + + return 0; + +} + +int thread_queue_get(struct threadqueue *queue, const struct timespec *timeout, struct threadmsg *msg) +{ + struct msglist *firstrec; + int ret = 0; + struct timespec abstimeout; + + if (queue == NULL || msg == NULL) { + return EINVAL; + } + if (timeout) { + struct timeval now; + + gettimeofday(&now, NULL); + abstimeout.tv_sec = now.tv_sec + timeout->tv_sec; + abstimeout.tv_nsec = (now.tv_usec * 1000) + timeout->tv_nsec; + if (abstimeout.tv_nsec >= 1000000000) { + abstimeout.tv_sec++; + abstimeout.tv_nsec -= 1000000000; + } + } + + pthread_mutex_lock(&queue->mutex); + + /* Will wait until awakened by a signal or broadcast */ + while (queue->first == NULL && ret != ETIMEDOUT) { //Need to loop to handle spurious wakeups + if (timeout) { + ret = pthread_cond_timedwait(&queue->cond, &queue->mutex, &abstimeout); + } else { + pthread_cond_wait(&queue->cond, &queue->mutex); + + } + } + if (ret == ETIMEDOUT) { + pthread_mutex_unlock(&queue->mutex); + return ret; + } + + firstrec = queue->first; + queue->first = queue->first->next; + queue->length--; + + if (queue->first == NULL) { + queue->last = NULL; // we know this since we hold the lock + queue->length = 0; + } + + + msg->data = firstrec->msg.data; + msg->msgtype = firstrec->msg.msgtype; + msg->qlength = queue->length; + + release_msglist(queue,firstrec); + pthread_mutex_unlock(&queue->mutex); + + return 0; +} + +//maybe caller should supply a callback for cleaning the elements ? +int thread_queue_cleanup(struct threadqueue *queue, int freedata) +{ + struct msglist *rec; + struct msglist *next; + struct msglist *recs[2]; + int ret,i; + if (queue == NULL) { + return EINVAL; + } + + pthread_mutex_lock(&queue->mutex); + recs[0] = queue->first; + recs[1] = queue->msgpool; + for(i = 0; i < 2 ; i++) { + rec = recs[i]; + while (rec) { + next = rec->next; + if (freedata) { + free(rec->msg.data); + } + free(rec); + rec = next; + } + } + + pthread_mutex_unlock(&queue->mutex); + ret = pthread_mutex_destroy(&queue->mutex); + pthread_cond_destroy(&queue->cond); + + return ret; + +} + +long thread_queue_length(struct threadqueue *queue) +{ + long counter; + // get the length properly + pthread_mutex_lock(&queue->mutex); + counter = queue->length; + pthread_mutex_unlock(&queue->mutex); + return counter; + +} diff --git a/util/thread_queue.h b/util/thread_queue.h new file mode 100644 index 0000000..3e196ab --- /dev/null +++ b/util/thread_queue.h @@ -0,0 +1,185 @@ +#ifndef _OAUTH2LIB_UTIL_THREAD_QUEUE_ +#define _OAUTH2LIB_UTIL_THREAD_QUEUE_ + +// https://stackoverflow.com/questions/4577961/pthread-synchronized-blocking-queue + +#include + +/** + * @defgroup ThreadQueue ThreadQueue + * + * Little API for waitable queues, typically used for passing messages + * between threads. + * + */ + +/** + * @mainpage + */ + +/** + * A thread message. + * + * @ingroup ThreadQueue + * + * This is used for passing to #thread_queue_get for retreive messages. + * the date is stored in the data member, the message type in the #msgtype. + * + * Typical: + * @code + * struct threadmsg; + * struct myfoo *foo; + * while(1) + * ret = thread_queue_get(&queue,NULL,&message); + * .. + * foo = msg.data; + * switch(msg.msgtype){ + * ... + * } + * } + * @endcode + * + */ +struct threadmsg{ + /** + * Holds the data. + */ + void *data; + /** + * Holds the messagetype + */ + long msgtype; + /** + * Holds the current queue lenght. Might not be meaningful if there's several readers + */ + long qlength; + +}; + + +/** + * A TthreadQueue + * + * @ingroup ThreadQueue + * + * You should threat this struct as opaque, never ever set/get any + * of the variables. You have been warned. + */ +struct threadqueue { +/** + * Length of the queue, never set this, never read this. + * Use #threadqueue_length to read it. + */ + long length; +/** + * Mutex for the queue, never touch. + */ + pthread_mutex_t mutex; +/** + * Condition variable for the queue, never touch. + */ + pthread_cond_t cond; +/** + * Internal pointers for the queue, never touch. + */ + struct msglist *first,*last; +/** + * Internal cache of msglists + */ + struct msglist *msgpool; +/** + * No. of elements in the msgpool + */ + long msgpool_length; +}; + +/** + * Initializes a queue. + * + * @ingroup ThreadQueue + * + * thread_queue_init initializes a new threadqueue. A new queue must always + * be initialized before it is used. + * + * @param queue Pointer to the queue that should be initialized + * @return 0 on success see pthread_mutex_init + */ +int thread_queue_init(struct threadqueue *queue); + +/** + * Adds a message to a queue + * + * @ingroup ThreadQueue + * + * thread_queue_add adds a "message" to the specified queue, a message + * is just a pointer to a anything of the users choice. Nothing is copied + * so the user must keep track on (de)allocation of the data. + * A message type is also specified, it is not used for anything else than + * given back when a message is retreived from the queue. + * + * @param queue Pointer to the queue on where the message should be added. + * @param data the "message". + * @param msgtype a long specifying the message type, choice of the user. + * @return 0 on succes ENOMEM if out of memory EINVAL if queue is NULL + */ +int thread_queue_add(struct threadqueue *queue, void *data, long msgtype); + +/** + * Gets a message from a queue + * + * @ingroup ThreadQueue + * + * thread_queue_get gets a message from the specified queue, it will block + * the caling thread untill a message arrives, or the (optional) timeout occurs. + * If timeout is NULL, there will be no timeout, and thread_queue_get will wait + * untill a message arrives. + * + * struct timespec is defined as: + * @code + * struct timespec { + * long tv_sec; // seconds + * long tv_nsec; // nanoseconds + * }; + * @endcode + * + * @param queue Pointer to the queue to wait on for a message. + * @param timeout timeout on how long to wait on a message + * @param msg pointer that is filled in with mesagetype and data + * + * @return 0 on success EINVAL if queue is NULL ETIMEDOUT if timeout occurs + */ +int thread_queue_get(struct threadqueue *queue, const struct timespec *timeout, struct threadmsg *msg); + + +/** + * Gets the length of a queue + * + * @ingroup ThreadQueue + * + * threadqueue_length returns the number of messages waiting in the queue + * + * @param queue Pointer to the queue for which to get the length + * @return the length(number of pending messages) in the queue + */ +long thread_queue_length( struct threadqueue *queue ); + +/** + * @ingroup ThreadQueue + * Cleans up the queue. + * + * threadqueue_cleanup cleans up and destroys the queue. + * This will remove all messages from a queue, and reset it. If + * freedata is != 0 free(3) will be called on all pending messages in the queue + * You cannot call this if there are someone currently adding or getting messages + * from the queue. + * After a queue have been cleaned, it cannot be used again untill #thread_queue_init + * has been called on the queue. + * + * @param queue Pointer to the queue that should be cleaned + * @param freedata set to nonzero if free(3) should be called on remaining + * messages + * @return 0 on success EINVAL if queue is NULL EBUSY if someone is holding any locks on the queue + */ +int thread_queue_cleanup(struct threadqueue *queue, int freedata); + +#endif //_OAUTH2LIB_UTIL_THREAD_QUEUE_