diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..94a88aa --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 3.20) + +project(OAuth2Client C) +set(CMAKE_C_STANDARD 23) + +set(CMAKE_BUILD_TYPE "DEBUG") +set(CMAKE_EXPORT_COMPILE_COMMANDS on) + +find_package(OpenSSL REQUIRED) + +add_library(OAuth2Lib + util/thread_queue.c + util/sorted_str_set.c + networking.c + pkce.c + base64.c + ssl.c + log.c + server.c) +target_include_directories(OAuth2Lib PRIVATE ${OpenSSL_INCLUDE_DIRS}) +target_link_libraries(OAuth2Lib PRIVATE OpenSSL::SSL OpenSSL::Crypto) + +add_executable(main main.c) +target_link_libraries(main OAuth2Lib) diff --git a/base64.c b/base64.c new file mode 100644 index 0000000..ac8a99d --- /dev/null +++ b/base64.c @@ -0,0 +1,286 @@ +#include "base64.h" + +#include +#include + +static const char base64url_table[65] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + +// http://web.mit.edu/freebsd/head/contrib/wpa/src/utils/base64.c +/* + * Base64 encoding/decoding (RFC1341) + * Copyright (c) 2005-2011, Jouni Malinen + * + * This software may be distributed under the terms of the BSD license. + * See README for more details. + */ +static const unsigned char base64_table[65] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +/** +* base64_encode - Base64 encode +* @src: Data to be encoded +* @len: Length of the data to be encoded +* @out_len: Pointer to output length variable, or %NULL if not used +* Returns: Allocated buffer of out_len bytes of encoded data, +* or %NULL on failure +* +* Caller is responsible for freeing the returned buffer. Returned buffer is +* nul terminated to make it easier to use as a C string. The nul terminator is +* not included in out_len. +*/ +char * base64_encode(const uint8_t *src, size_t len, + size_t *out_len) +{ + unsigned char *out, *pos; + const unsigned char *end, *in; + size_t olen; + int line_len; + + olen = len * 4 / 3 + 4; /* 3-byte blocks to 4-byte */ + olen += olen / 72; /* line feeds */ + olen++; /* nul termination */ + if (olen < len) + return NULL; /* integer overflow */ + out = malloc(olen); + if (out == NULL) + return NULL; + + end = src + len; + in = src; + pos = out; + line_len = 0; + while (end - in >= 3) { + *pos++ = base64_table[in[0] >> 2]; + *pos++ = base64_table[((in[0] & 0x03) << 4) | (in[1] >> 4)]; + *pos++ = base64_table[((in[1] & 0x0f) << 2) | (in[2] >> 6)]; + *pos++ = base64_table[in[2] & 0x3f]; + in += 3; + line_len += 4; + if (line_len >= 72) { + *pos++ = '\n'; + line_len = 0; + } + } + + if (end - in) { + *pos++ = base64_table[in[0] >> 2]; + if (end - in == 1) { + *pos++ = base64_table[(in[0] & 0x03) << 4]; + *pos++ = '='; + } else { + *pos++ = base64_table[((in[0] & 0x03) << 4) | + (in[1] >> 4)]; + *pos++ = base64_table[(in[1] & 0x0f) << 2]; + } + *pos++ = '='; + line_len += 4; + } + + if (line_len) + *pos++ = '\n'; + + *pos = '\0'; + if (out_len) + *out_len = pos - out; + return out; +} + + +/** + * base64_decode - Base64 decode + * @src: Data to be decoded + * @len: Length of the data to be decoded + * @out_len: Pointer to output length variable + * Returns: Allocated buffer of out_len bytes of decoded data, + * or %NULL on failure + * + * Caller is responsible for freeing the returned buffer. + */ +uint8_t* base64_decode(const char *src, size_t len, + size_t *out_len) +{ + unsigned char dtable[256], *out, *pos, block[4], tmp; + size_t i, count, olen; + int pad = 0; + + memset(dtable, 0x80, 256); + for (i = 0; i < sizeof(base64_table) - 1; i++) + dtable[base64_table[i]] = (unsigned char) i; + dtable['='] = 0; + + count = 0; + for (i = 0; i < len; i++) { + if (dtable[src[i]] != 0x80) + count++; + } + + if (count == 0 || count % 4) + return NULL; + + olen = count / 4 * 3; + pos = out = malloc(olen); + if (out == NULL) + return NULL; + + count = 0; + for (i = 0; i < len; i++) { + tmp = dtable[src[i]]; + if (tmp == 0x80) + continue; + + if (src[i] == '=') + pad++; + block[count] = tmp; + count++; + if (count == 4) { + *pos++ = (block[0] << 2) | (block[1] >> 4); + *pos++ = (block[1] << 4) | (block[2] >> 2); + *pos++ = (block[2] << 6) | block[3]; + count = 0; + if (pad) { + if (pad == 1) + pos--; + else if (pad == 2) + pos -= 2; + else { + /* Invalid padding */ + free(out); + return NULL; + } + break; + } + } + } + + *out_len = pos - out; + return out; +} + +char * base64url_encode(const uint8_t *src, size_t len, + size_t *out_len) +{ + unsigned char *out, *pos; + const unsigned char *end, *in; + size_t olen; + int line_len; + + olen = len * 4 / 3 + 4; /* 3-byte blocks to 4-byte */ + olen += olen / 72; /* line feeds */ + olen++; /* nul termination */ + if (olen < len) + return NULL; /* integer overflow */ + out = malloc(olen); + if (out == NULL) + return NULL; + + end = src + len; + in = src; + pos = out; + line_len = 0; + while (end - in >= 3) { + *pos++ = base64url_table[in[0] >> 2]; + *pos++ = base64url_table[((in[0] & 0x03) << 4) | (in[1] >> 4)]; + *pos++ = base64url_table[((in[1] & 0x0f) << 2) | (in[2] >> 6)]; + *pos++ = base64url_table[in[2] & 0x3f]; + in += 3; + line_len += 4; + if (line_len >= 72) { + *pos++ = '\n'; + line_len = 0; + } + } + + if (end - in) { + *pos++ = base64url_table[in[0] >> 2]; + if (end - in == 1) { + *pos++ = base64url_table[(in[0] & 0x03) << 4]; + *pos++ = '='; + } else { + *pos++ = base64url_table[((in[0] & 0x03) << 4) | + (in[1] >> 4)]; + *pos++ = base64url_table[(in[1] & 0x0f) << 2]; + } + *pos++ = '='; + line_len += 4; + } + + if (line_len) + *pos++ = '\n'; + + *pos = '\0'; + if (out_len) + *out_len = pos - out; + return out; +} + + +/** + * base64_decode - Base64 decode + * @src: Data to be decoded + * @len: Length of the data to be decoded + * @out_len: Pointer to output length variable + * Returns: Allocated buffer of out_len bytes of decoded data, + * or %NULL on failure + * + * Caller is responsible for freeing the returned buffer. + */ +uint8_t* base64url_decode(const char *src, size_t len, + size_t *out_len) +{ + unsigned char dtable[256], *out, *pos, block[4], tmp; + size_t i, count, olen; + int pad = 0; + + memset(dtable, 0x80, 256); + for (i = 0; i < sizeof(base64url_table) - 1; i++) + dtable[base64url_table[i]] = (unsigned char) i; + dtable['='] = 0; + + count = 0; + for (i = 0; i < len; i++) { + if (dtable[src[i]] != 0x80) + count++; + } + + if (count == 0 || count % 4) + return NULL; + + olen = count / 4 * 3; + pos = out = malloc(olen); + if (out == NULL) + return NULL; + + count = 0; + for (i = 0; i < len; i++) { + tmp = dtable[src[i]]; + if (tmp == 0x80) + continue; + + if (src[i] == '=') + pad++; + block[count] = tmp; + count++; + if (count == 4) { + *pos++ = (block[0] << 2) | (block[1] >> 4); + *pos++ = (block[1] << 4) | (block[2] >> 2); + *pos++ = (block[2] << 6) | block[3]; + count = 0; + if (pad) { + if (pad == 1) + pos--; + else if (pad == 2) + pos -= 2; + else { + /* Invalid padding */ + free(out); + return NULL; + } + break; + } + } + } + + *out_len = pos - out; + return out; +} diff --git a/base64.h b/base64.h new file mode 100644 index 0000000..82d2b41 --- /dev/null +++ b/base64.h @@ -0,0 +1,37 @@ +#ifndef _OSAUTH2LIB_BASE64_ +#define _OSAUTH2LIB_BASE64_ + +#include +#include + +/** +* base64_encode - Base64 encode +* @src: Data to be encoded +* @len: Length of the data to be encoded +* @out_len: Pointer to output length variable, or %NULL if not used +* Returns: Allocated buffer of out_len bytes of encoded data, +* or %NULL on failure +* +* Caller is responsible for freeing the returned buffer. Returned buffer is +* nul terminated to make it easier to use as a C string. The nul terminator is +* not included in out_len. +*/ +char *base64_encode(const uint8_t* src, size_t len, size_t* out_len); + +/** + * base64_decode - Base64 decode + * @src: Data to be decoded + * @len: Length of the data to be decoded + * @out_len: Pointer to output length variable + * Returns: Allocated buffer of out_len bytes of decoded data, + * or %NULL on failure + * + * Caller is responsible for freeing the returned buffer. + */ +uint8_t *base64_decode(const char* src, size_t len, size_t* out_len); + + +char *base64url_encode(const uint8_t* src, size_t len, size_t* out_len); +uint8_t *base64url_decode(const char* src, size_t len, size_t* out_len); + +#endif //_OSAUTH2LIB_BASE64_ diff --git a/log.c b/log.c new file mode 100644 index 0000000..156bf3e --- /dev/null +++ b/log.c @@ -0,0 +1,176 @@ +#include "log.h" + +#include +#include +#include + +#include + +#include "util/sorted_str_set.h" + +static const char* LOG_LEVEL_STRING_TABLE[] = { + "ERROR", + "WARN", + "INFO", + "DEBUG", + "TRACE" +}; + +// Maybe add __attribute__((constructor)) in the future if available to skip _config_initialized +static bool _user_default_config_initialized = false; +static struct log_config_s _user_default_config; + +static sorted_str_set* _module_configs = NULL; + +static inline struct log_config_s _default_config() { + + if (_user_default_config_initialized) { + return _user_default_config; + } + + return (struct log_config_s) { + .loc_error = LOG_LOC_STDS(stderr), + .loc_warn = LOG_LOC_STDS(stderr), + .loc_info = LOG_LOC_STDS(stdout), + .loc_debug = LOG_LOC_NONE(), + .loc_trace = LOG_LOC_NONE(), + + .timestamp = true, + }; +} + +void log_set_default_config(struct log_config_s config) { + _user_default_config = config; + _user_default_config_initialized = true; +} + +struct log_config_s* log_get_config(const char* module_name) { + + assert(module_name != NULL); + + if (!_module_configs) { + if (!(_module_configs = sorted_str_set_new(sizeof(struct log_config_s)))) { + return NULL; + } + } + + struct log_config_s default_conf = _default_config(); + return sorted_str_set_insert(_module_configs, module_name, &default_conf); +} + +void log_log(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...) { + + assert(module != NULL); + + struct log_config_s* config = log_get_config(module); + assert(config); + + struct log_loc* dest_loc; + switch (log_level) { + case LOG_ERR: dest_loc = &config->loc_error; break; + case LOG_WARN: dest_loc = &config->loc_warn; break; + case LOG_INFO: dest_loc = &config->loc_info; break; + case LOG_DEBUG: dest_loc = &config->loc_debug; break; + case LOG_TRACE: dest_loc = &config->loc_trace; break; + default: return; + } + + if (dest_loc->type == LOG_LOC_TYPE_DISABLED) { + return; + } else if (dest_loc->type == LOG_LOC_TYPE_FS) { + + if (!dest_loc->location) { + if (!(dest_loc->location = fopen(dest_loc->file_name, "a"))) { + return; + } + } + + } else if (dest_loc->type != LOG_LOC_TYPE_STDS) { + return; + } + + FILE* dest = dest_loc->location; + + va_list varargs; + va_start(varargs, fmt); + + if (config->timestamp) { + struct timespec utc_time; + struct tm local_time; + char time_str_buf[10+1+8+1]; + timespec_get(&utc_time, TIME_UTC); + localtime_r(&(utc_time.tv_sec), &local_time); + + if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) { + fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec); + } + } + + fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], module); + vfprintf(dest, fmt, varargs); + fprintf(dest, "\n"); + + if (dest_loc->type == LOG_LOC_TYPE_FS && !dest_loc->keep_file_open) { + fclose(dest_loc->location); + dest_loc->location = NULL; + } +} + +void log_log_ssl_err(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...) { + assert(module != NULL); + + struct log_config_s* config = log_get_config(module); + assert(config); + + struct log_loc* dest_loc; + switch (log_level) { + case LOG_ERR: dest_loc = &config->loc_error; break; + case LOG_WARN: dest_loc = &config->loc_warn; break; + case LOG_INFO: dest_loc = &config->loc_info; break; + case LOG_DEBUG: dest_loc = &config->loc_debug; break; + case LOG_TRACE: dest_loc = &config->loc_trace; break; + default: return; + } + + if (dest_loc->type == LOG_LOC_TYPE_DISABLED) { + return; + } else if (dest_loc->type == LOG_LOC_TYPE_FS) { + + if (!dest_loc->location) { + if (!(dest_loc->location = fopen(dest_loc->file_name, "a"))) { + return; + } + } + + } else if (dest_loc->type != LOG_LOC_TYPE_STDS) { + return; + } + + FILE* dest = dest_loc->location; + + va_list varargs; + va_start(varargs, fmt); + + if (config->timestamp) { + struct timespec utc_time; + struct tm local_time; + char time_str_buf[10+1+8+1]; + timespec_get(&utc_time, TIME_UTC); + localtime_r(&(utc_time.tv_sec), &local_time); + + if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) { + fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec); + } + } + + fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], module); + vfprintf(dest, fmt, varargs); + fprintf(dest, "\n"); + ERR_print_errors_fp(dest); + + if (dest_loc->type == LOG_LOC_TYPE_FS && !dest_loc->keep_file_open) { + fclose(dest_loc->location); + dest_loc->location = NULL; + } + +} diff --git a/log.c.old b/log.c.old new file mode 100644 index 0000000..bb69cc9 --- /dev/null +++ b/log.c.old @@ -0,0 +1,97 @@ +#include "log.h" + +#include +#include +#include + +#include + +static const char* LOG_LEVEL_STRING_TABLE[] = { + "ERROR", + "WARN", + "INFO", + "DEBUG", + "TRACE" +}; + +// Maybe add __attribute__((constructor)) in the future if available to skip _config_initialized +static bool _user_default_config_initialized = false; +static struct log_config_s _user_default_config; + +static inline struct log_config_s _default_config() { + + if (_user_default_config_initialized) { + return _user_default_config; + } + + return (struct log_config_s) { + .loc_error = stderr, + .loc_warn = stderr, + .loc_info = stdout, + .loc_debug = NULL, + .loc_trace = NULL, + + .timestamp = true, + }; +} + +void log_set_default_config(struct log_config_s config) { + _user_default_config = config; + _user_default_config_initialized = true; +} + +struct logger_s log_new(const char* module_name) { + return (struct logger_s) { + .module_name = module_name, + .config = _default_config(), + }; +} + +void log_log(const struct logger_s* logger, enum LOG_LEVEL log_level, const char* fmt, ...) { + + assert(logger != NULL); + + FILE* dest = NULL; + switch (log_level) { + case LOG_ERR: dest = logger->config.loc_error; break; + case LOG_WARN: dest = logger->config.loc_warn; break; + case LOG_INFO: dest = logger->config.loc_info; break; + case LOG_DEBUG: dest = logger->config.loc_debug; break; + case LOG_TRACE: dest = logger->config.loc_trace; break; + default: break; + } + + if (dest == NULL) { + return; + } + + va_list varargs; + va_start(varargs, fmt); + + if (logger->config.timestamp) { + struct timespec utc_time; + struct tm local_time; + char time_str_buf[10+1+8+1]; + timespec_get(&utc_time, TIME_UTC); + localtime_r(&(utc_time.tv_sec), &local_time); + + if(strftime(time_str_buf, sizeof(time_str_buf), "%F %T", &local_time)) { + fprintf(dest, "[%s.%9lu] ", time_str_buf, utc_time.tv_nsec); + } + } + + fprintf(dest, "[%s] %s: ", LOG_LEVEL_STRING_TABLE[log_level], logger->module_name); + vfprintf(dest, fmt, varargs); + fprintf(dest, "\n"); +} + +void log_log_ssl_err(const struct logger_s* logger, enum LOG_LEVEL log_level, unsigned long error, const char* message_fmt, ...) { + char error_buf[256]; + ERR_error_string_n(ERR_get_error(), error_buf, sizeof(error_buf)); + + va_list varargs; + va_start(varargs, message_fmt); + + log_log(logger, log_level, message_fmt, varargs); + log_log(logger, log_level, "%s", error_buf); +} diff --git a/log.h b/log.h new file mode 100644 index 0000000..f986089 --- /dev/null +++ b/log.h @@ -0,0 +1,54 @@ +#ifndef _OAUTH2LIB_LOG_ +#define _OAUTH2LIB_LOG_ + +#include +#include + +enum LOG_LEVEL { + LOG_ERR, + LOG_WARN, + LOG_INFO, + LOG_DEBUG, + LOG_TRACE, +}; + +enum LOG_LOC_TYPE { + LOG_LOC_TYPE_DISABLED, + LOG_LOC_TYPE_FS, + LOG_LOC_TYPE_STDS +}; + +struct log_loc { + enum LOG_LOC_TYPE type; + FILE* location; + const char* file_name; + bool keep_file_open; +}; + +#define LOG_LOC_FILE(file_location, keep_open) ((struct log_loc) { .type=LOG_LOC_TYPE_FS, .location=NULL, .file_name=file_location, .keep_file_open=keep_open }) +#define LOG_LOC_STDS(stream) ((struct log_loc) { .type=LOG_LOC_TYPE_STDS, .location=stream, .file_name=NULL }) +#define LOG_LOC_NONE() ((struct log_loc) { .type=LOG_LOC_TYPE_DISABLED }) + +struct log_config_s { + struct log_loc loc_error; + struct log_loc loc_warn; + struct log_loc loc_info; + struct log_loc loc_debug; + struct log_loc loc_trace; + + bool timestamp; +}; + +struct logger_s { + const char* module_name; + struct log_config_s config; +}; + +void log_set_default_config(struct log_config_s config); +struct log_config_s* log_get_config(const char* module_name); + +void log_log(const char* module, enum LOG_LEVEL log_level, const char* fmt, ...); + +void log_log_ssl_err(const char* module, enum LOG_LEVEL log_level, const char* message_fmt, ...); + +#endif //_OAUTH2LIB_LOG_ diff --git a/log.h.old b/log.h.old new file mode 100644 index 0000000..f34cba2 --- /dev/null +++ b/log.h.old @@ -0,0 +1,38 @@ +#ifndef _OAUTH2LIB_LOG_ +#define _OAUTH2LIB_LOG_ + +#include +#include +#include + +enum LOG_LEVEL { + LOG_ERR, + LOG_WARN, + LOG_INFO, + LOG_DEBUG, + LOG_TRACE, +}; + +struct log_config_s { + FILE* loc_error; + FILE* loc_warn; + FILE* loc_info; + FILE* loc_debug; + FILE* loc_trace; + + bool timestamp; +}; + +struct logger_s { + const char* module_name; + struct log_config_s config; +}; + +void log_set_default_config(struct log_config_s config); +struct logger_s log_new(const char* module_name); + +void log_log(const struct logger_s* logger, enum LOG_LEVEL log_level, const char* fmt, ...); + +void log_log_ssl_err(const struct logger_s* logger, enum LOG_LEVEL log_level, unsigned long error, const char* message_fmt, ...); + +#endif //_OAUTH2LIB_LOG_ diff --git a/main.c b/main.c new file mode 100644 index 0000000..004aedc --- /dev/null +++ b/main.c @@ -0,0 +1,23 @@ +#include + +#include "networking.h" +#include "pkce.h" +#include "log.h" +#include "server.h" +#include "ssl.h" + +#include + +int main(int argc, const char* argv[]) { + + /*SSL_CTX* ctx = SSL_CTX_new(TLS_server_method()); + ssl_use_ppcerts(ctx);*/ + + // tcp_get(argv[1]); + + //do_pkce(); + + server_start(); + + return 0; +} diff --git a/main.c.old b/main.c.old new file mode 100644 index 0000000..d402245 --- /dev/null +++ b/main.c.old @@ -0,0 +1,131 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//#include "dns.h" + +#define WEB_ADDR "www.onedrive.com" + +struct socket_data_chunk { + char data[4096]; + struct socket_data_chunk *next; +}; + +void free_socketdata(struct socket_data_chunk* data_chunk) { + for (struct socket_data_chunk *iter = data_chunk; iter != NULL; ) { + struct socket_data_chunk *data = iter; + iter = iter->next; + + free(data); + } +} + +int strncmp_end(const char *str, size_t max_strlen, const char *suffix, size_t max_suffix_len) +{ + if (!str || !suffix) + return 0; + size_t lenstr = strnlen(str, max_strlen); + size_t lensuffix = strnlen(suffix, max_suffix_len); + if (lensuffix > lenstr) + return 0; + return strncmp(str + lenstr - lensuffix, suffix, lensuffix); +} + + +int main() { + + //oauth2_dns_resolve("onedrive.com"); + + struct addrinfo resolve_hints = { + .ai_family = AF_INET, //AF_UNSPEC, // IPv4 or IPv6 + .ai_socktype = SOCK_STREAM, + .ai_flags = 0, + .ai_protocol = 0 + }; + + struct addrinfo *resolve_result; + + int get_addr_res = getaddrinfo(WEB_ADDR, NULL, &resolve_hints, &resolve_result); + if (get_addr_res != 0) { + printf("getaddrinfo failed: %d\n", get_addr_res); + exit(-1); + } + char ip_addr[50]; + printf("IP address: %s\n", inet_ntop(AF_INET, &((struct sockaddr_in*) resolve_result->ai_addr)->sin_addr, ip_addr, sizeof(ip_addr))); + printf("Port: %d\n", ntohs(((struct sockaddr_in*)resolve_result->ai_addr)->sin_port)); + + ((struct sockaddr_in*)resolve_result->ai_addr)->sin_port = htons(80); + + int sock_desc = socket(AF_INET, SOCK_STREAM, 0); + if (sock_desc == -1) { + printf("Socket creation failed. Errno: %d\n", errno); + exit(-1); + } + + printf("created socket\n"); + + struct sockaddr_in server_addr = *(struct sockaddr_in *)resolve_result->ai_addr; + +// struct sockaddr_in server_addr = { +// .sin_family = AF_INET, +// .sin_addr = inet_addr("127.0.0.1"), +// .sin_port = htons(8080) +// }; + + if (connect(sock_desc, (struct sockaddr*)&server_addr, sizeof(server_addr)) != 0) { + printf("Socket connection failed. Errno: %d\n", errno); + exit(-1); + close(sock_desc); + } + + printf("Connected to address\n"); + + const char *msg = + "GET / HTTP/1.1\r\n" + "Host: " WEB_ADDR "\r\n" + "\r\n"; + write(sock_desc, msg, strlen(msg)); + + printf("sent: \n%s\n", msg); + + struct socket_data_chunk *data = NULL; + struct socket_data_chunk **it = &data; + do { + struct socket_data_chunk *nit = calloc(1, sizeof(struct socket_data_chunk)); + if (!nit) { + printf("failed to allocate data for receive\n"); + exit(-1); + } + + if (*it) { + (*it)->next = nit; + } else { + *it = nit; + } + it = &nit; + + ssize_t bytes_read = recv(sock_desc, &nit->data, sizeof(nit->data), 0); + if (bytes_read == -1) { + printf("Failed to recieve data. Errno: %d", errno); + exit(-1); + } + + printf("received (%ld) bytes:\n%s\n", bytes_read, nit->data); + + //} while (!strstr((*it)->data, "\r\n\r\n")); + } while(strncmp_end(data->data, sizeof((*it)->data), "\r\n\r\n", 5)); + + printf("Finishing\n"); + + close(sock_desc); + freeaddrinfo(resolve_result); + free_socketdata(data); + + return 0; +} diff --git a/networking.c b/networking.c new file mode 100644 index 0000000..38c9fcd --- /dev/null +++ b/networking.c @@ -0,0 +1,221 @@ +#include "networking.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include "ssl.h" + +typedef int socket_fd; + +static SSL_CTX *ssl_ctx = NULL; + +char *strstr_fixn(const char *haystack, const char *needle, size_t len) { + size_t needle_len; + + if (0 == (needle_len = strnlen(needle, len))) + return (char *)haystack; + + for (size_t i = 0; i <= len - needle_len; i++) + { + if ((haystack[0] == needle[0]) && + (0 == strncmp(haystack, needle, needle_len))) + return (char *)haystack; + + haystack++; + } + return NULL; +} + + + +static socket_fd connect_to_http(const char* hostname) { + + struct addrinfo resolve_hints = { + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + }; + + struct addrinfo* resolve_result; + int result = getaddrinfo(hostname, "https", &resolve_hints, &resolve_result); + + if (result != 0) { + fprintf(stderr, "Failed to get IP address of \"%s\" for an http connection. Error: %s\n", hostname, gai_strerror(result)); + return -1; + } + + socket_fd sock_desc = -1; + for (struct addrinfo *it = resolve_result; it != NULL; it = it->ai_next) { + + sock_desc = socket(it->ai_family, it->ai_socktype, it->ai_protocol); + if (sock_desc == -1) { + continue; + } + + if (connect(sock_desc, it->ai_addr, it->ai_addrlen) != -1) { + uint16_t port = ntohs( + (it->ai_family == AF_INET) + ? ((struct sockaddr_in*) it->ai_addr)->sin_port + : ((struct sockaddr_in6*) it->ai_addr)->sin6_port); + printf("Connected to \"%s\" on port %d\n", hostname, port); + break; + } + close(sock_desc); + } + + freeaddrinfo(resolve_result); + + return sock_desc; +} +void tcp_get(const char* addr) { + + initialize_ssl(); + printf("initialized ssl\n"); + + + socket_fd sock_fd = connect_to_http(addr); + + printf("Got socket: %d\n", sock_fd); + + SSL* ssl_conn = SSL_new(ssl_ctx); + if (!ssl_conn) { + fprintf(stderr, "Failed to create an SSL connection provider.\n"); + ERR_print_errors_fp(stderr); + + close(sock_fd); + return; + } + + SSL_set_fd(ssl_conn, sock_fd); + + if (!SSL_set_tlsext_host_name(ssl_conn, addr)) { + fprintf(stderr, "Failed to set SNI hostname.\n"); + ERR_print_errors_fp(stderr); + + SSL_free(ssl_conn); + return; + } + + if (!SSL_set1_host(ssl_conn, addr)) { + fprintf(stderr, "Failed to set certificate hostname.\n"); + ERR_print_errors_fp(stderr); + + SSL_free(ssl_conn); + return; + } + + if (SSL_connect(ssl_conn) < 1) { + fprintf(stderr, "Failed to connect to server.\n"); + + if (SSL_get_verify_result(ssl_conn) != X509_V_OK) { + fprintf(stderr, "Verification error: %s\n", X509_verify_cert_error_string(SSL_get_verify_result(ssl_conn))); + } else { + ERR_print_errors_fp(stderr); + } + + SSL_free(ssl_conn); + return; + } + + printf("Connected to peer.\n"); + + const char* HTTP_HEAD = "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "Host: "; + const char* HTTP_HEAD_END = "\r\n\r\n"; + + const char* msg[] = { HTTP_HEAD, addr, HTTP_HEAD_END }; + + for (size_t i = 0; i < sizeof(msg)/sizeof(msg[0]); ++i) { + + if (!SSL_write(ssl_conn, msg[i], strlen(msg[i]))) { + fprintf(stderr, "Failed to send message: \"%s\"\n", msg[i]); + ERR_print_errors_fp(stderr); + + SSL_free(ssl_conn); + return; + } + } + + printf("Sent HTTP request to peer.\n"); + + size_t read_bytes; + char read_buf[160]; + + enum { HTTP_CONTENT_LENGTH, HTTP_CONTENT_CHUNKED } http_content_transmit_type; + size_t total_read_bytes; + size_t expected_bytes = -1; + bool body = false; + + char last_3_of_prev[3] = {}; + + printf("\n"); + while (SSL_read_ex(ssl_conn, read_buf, sizeof(read_buf), &read_bytes)) { + total_read_bytes += read_bytes; + + if ((http_content_transmit_type == HTTP_CONTENT_CHUNKED && read_buf[0] == '0') + || http_content_transmit_type == HTTP_CONTENT_LENGTH && total_read_bytes == expected_bytes) { + SSL_set_options(ssl_conn, SSL_OP_IGNORE_UNEXPECTED_EOF); + } + + if (!body) { + const char* content_length_str = "Content-Length: "; + const char* transfer_encoding_str = "Transfer-Encoding: chunked"; + + char* content_length_start = strstr_fixn(read_buf, content_length_str, read_bytes); + if (content_length_start != NULL) { + http_content_transmit_type = HTTP_CONTENT_LENGTH; + + content_length_start += strlen(content_length_str); + expected_bytes = strtol(content_length_start, NULL, 10); + } + + if (strstr_fixn(read_buf, transfer_encoding_str, read_bytes)) { + http_content_transmit_type = HTTP_CONTENT_CHUNKED; + } + + char concat[] = { last_3_of_prev[0], last_3_of_prev[1], last_3_of_prev[2], read_buf[0], read_buf[1], read_buf[2], 0 }; + char *concat_start; + if ((concat_start = strstr(concat, HTTP_HEAD_END))) { + body = true; + total_read_bytes = read_bytes - (concat_start - concat + 1); + } else { + char* body_start = strstr_fixn(read_buf, HTTP_HEAD_END, read_bytes); + if (body_start) { + body = true; + total_read_bytes = read_bytes - (body_start + strlen(HTTP_HEAD_END) - read_buf); + } + } + } + + + fwrite(read_buf, 1, read_bytes, stdout); + } + printf("\n"); + + if (SSL_get_error(ssl_conn, 0) != SSL_ERROR_ZERO_RETURN) { + fprintf(stderr, "Failed to read response.\n"); + ERR_print_errors_fp(stderr); + } + + printf("Recieved %ld bytes", total_read_bytes); + printf(expected_bytes != -1 ? ", expected: %ld bytes\n" : "\n", expected_bytes); + printf("Ignore unexpected eof? %s\n", SSL_get_options(ssl_conn) & SSL_OP_IGNORE_UNEXPECTED_EOF ? "true" : "false"); + + if (SSL_shutdown(ssl_conn) < 1) { + fprintf(stderr, "Failed to shut down connection gracefully.\n"); + ERR_print_errors_fp(stderr); + } + + SSL_free(ssl_conn); + + printf("end\n"); +} diff --git a/networking.h b/networking.h new file mode 100644 index 0000000..caf3254 --- /dev/null +++ b/networking.h @@ -0,0 +1,11 @@ +#ifndef _OAUTH2LIB_NETWORKING_ +#define _OAUTH2LIB_NETWORKING_ + +struct content { + +}; + +void tcp_get(const char* addr); + + +#endif //_OAUTH2LIB_NETWORKING_ diff --git a/pkce.c b/pkce.c new file mode 100644 index 0000000..4f44e45 --- /dev/null +++ b/pkce.c @@ -0,0 +1,145 @@ +#include "pkce.h" + +#include +#include + +#include +#include +#include + +#include "base64.h" +#include "log.h" +#include "ssl.h" + +#define MODULE_NAME "PKCE" + +#define CODE_VERIFIER_LENGTH 32 // the length of the binary random string before base64url encoding + +struct code_verifier_s { + size_t length; + char* key; +}; + + + +static inline void free_code_verifier(struct code_verifier_s* ptr) { free(ptr->key); } + +static struct code_verifier_s generate_code_verifier() { + + uint8_t code_buffer[CODE_VERIFIER_LENGTH]; + + int gen_result = RAND_priv_bytes(code_buffer, sizeof(code_buffer)); + if (gen_result != 1) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to generate code_verifier.\n"); + + return (struct code_verifier_s) { 0, NULL }; + } + + size_t encoded_size = 0; + char* encoded = base64url_encode(code_buffer, sizeof(code_buffer), &encoded_size); + if (!encoded) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to encode code_verifier.\n"); + return (struct code_verifier_s) { 0, NULL }; + } + + // trim trailing =\n + encoded_size -= 2; + encoded[encoded_size] = '\0'; + + return (struct code_verifier_s) { encoded_size, encoded }; +} + +static uint8_t* sha256_encode(const char* str, size_t str_len) { + + EVP_MD_CTX* hash_ctx = EVP_MD_CTX_new(); + if (!hash_ctx) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to create Crypto context.\n"); + + return NULL; + } + + EVP_MD* sha256_ctx = ssl_get_sha256(); + if (!sha256_ctx) { + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + if (!EVP_DigestInit_ex(hash_ctx, sha256_ctx, NULL)) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to initialize sha256 digest.\n"); + + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + if (!EVP_DigestUpdate(hash_ctx, str, str_len)) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to set message to digest.\n"); + + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + uint8_t *hash_buf = malloc(EVP_MD_get_size(sha256_ctx)); + if (!hash_buf) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to allocate memory for digest output.\n"); + + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + unsigned int len = 0; + if (!EVP_DigestFinal_ex(hash_ctx, hash_buf, &len)) { + log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to compute hash.\n"); + + EVP_MD_CTX_free(hash_ctx); + return NULL; + } + + EVP_MD_CTX_free(hash_ctx); + + return hash_buf; +} + +char* _construct_get(char* server_url, char* client_id, char* redirect_uri, char* scope, char* state) { + + static char msg_template[] = "GET %s?response_type=code&client_id=%s&state=%s&redirect_uri=%s HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t len = strlen(msg_template) - 4*2 + strlen(server_url)+strlen(client_id)+strlen(redirect_uri)+strlen(scope)+strlen(state) + 1; + + char* buf = malloc(len); + if (!buf) { return NULL; } + + if (snprintf(buf, len, msg_template, server_url, client_id, state, redirect_uri) < 0) { + free(buf); + return NULL; + } + + return buf; +} + +void do_pkce() { + + log_log("PKCE", LOG_INFO, "generating code verifier"); + + struct code_verifier_s code_verifier = generate_code_verifier(); + + log_log("PKCE", LOG_DEBUG, "generated verifer: %s\n", code_verifier.key); + log_log("PKCE", LOG_INFO, "generating code challenge"); + + uint8_t* code_challenge_sha256 = sha256_encode(code_verifier.key, code_verifier.length); + if (!code_challenge_sha256) { + log_log("PKCE", LOG_ERR, "Failed to encode code challenge"); + goto pkce_cleanup; + } + + size_t code_challenge_length = 0; + char* code_challenge = base64url_encode(code_challenge_sha256, EVP_MD_get_size(ssl_get_sha256()), &code_challenge_length); + + code_challenge_length -= 2; + code_challenge[code_challenge_length] = '\0'; + + log_log("PKCE", LOG_DEBUG, "generated code challenge: %s\n", code_challenge); + +pkce_cleanup: + free_code_verifier(&code_verifier); + free(code_challenge); +} diff --git a/pkce.h b/pkce.h new file mode 100644 index 0000000..037d9b1 --- /dev/null +++ b/pkce.h @@ -0,0 +1,6 @@ +#ifndef _OAUTH2LIB_PKCE_ +#define _OAUTH2LIB_PKCE_ + +void do_pkce(); + +#endif //_OAUTH2LIB_PKCE_ diff --git a/server.c b/server.c new file mode 100644 index 0000000..9b1ce92 --- /dev/null +++ b/server.c @@ -0,0 +1,141 @@ +#include "server.h" + +#include +#include +#include +#include + +#include + +#include "log.h" +#include "ssl.h" + +#include "util/thread_queue.h" + +#define MODULE_NAME "login_server" + +#define SERVER_PORT 443 + +static const char message[] = "

Hello World!

"; + +static SSL_CTX* _ssl_server_ctx = NULL; +static struct logger_s _logger; + +static pthread_t *_server_thread = NULL; +static struct threadqueue _server_thread_queue; + +_Atomic enum server_state _server_state = SERVER_NOT_STARTED; + +static enum server_start_return _server_thread_func(void* data) { + + atomic_store(&_server_state, SERVER_STARTING); + + struct sockaddr_in sockaddr = { + .sin_family = AF_INET, + .sin_port = htons(SERVER_PORT), + .sin_addr = { + .s_addr = htonl(INADDR_ANY) + } + }; + + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + log_log(MODULE_NAME, LOG_ERR, "Failed to create socket: %s", strerror(errno)); + atomic_store(&_server_state, SERVER_FAILED); + return SERVER_FAIL; + } + + if (bind(sock, (struct sockaddr*)&sockaddr, sizeof(sockaddr)) < 0) { + log_log(MODULE_NAME, LOG_ERR, "Failed to bind socket: %s", strerror(errno)); + atomic_store(&_server_state, SERVER_FAILED); + close(sock); + return SERVER_FAIL; + } + + if (listen(sock, 1) != 0) { + log_log(MODULE_NAME, LOG_ERR, "Failed to start listening on socket: %s", strerror(errno)); + atomic_store(&_server_state, SERVER_FAILED); + close(sock); + return SERVER_FAIL; + } + + log_log(MODULE_NAME, LOG_INFO, "Waiting for incoming connections..."); + atomic_store(&_server_state, SERVER_RUNNING); + + while(1) { + struct sockaddr addr; + socklen_t addr_len = sizeof(addr); + + int client_fd = accept(sock, &addr, &addr_len); + if (client_fd < 0) { + log_log(MODULE_NAME, LOG_ERR, "Failed to accept an incoming connection: %s", strerror(errno)); + + continue; + } + + SSL* ssl_conn = SSL_new(_ssl_server_ctx); + SSL_set_fd(ssl_conn, client_fd); + + if (SSL_accept(ssl_conn) <= 0) { + log_log_ssl_err(MODULE_NAME, LOG_ERR, "SSL Handshake failed."); + } else { + + size_t bufsize = 1024; + char* buf = malloc(bufsize); + + if (SSL_read(ssl_conn, buf, bufsize) > 0) { + log_log(MODULE_NAME, LOG_INFO, "Received message: %s", buf); + } + + SSL_write(ssl_conn, message, sizeof(message)); + } + + SSL_shutdown(ssl_conn); + SSL_free(ssl_conn); + close(client_fd); + + } + +_server_thread_shutdown: + close(sock); +} + +enum server_start_return server_start() { + initialize_ssl(); + + if (!_ssl_server_ctx) { + _ssl_server_ctx = SSL_CTX_new(TLS_server_method()); + + if (!_ssl_server_ctx) { + log_log(MODULE_NAME, LOG_ERR, "Failed to create SSL CTX"); + return SERVER_FAIL; + } + + if (!ssl_use_ppcerts(_ssl_server_ctx)) { + log_log(MODULE_NAME, LOG_ERR, "Failed to set certificates."); + SSL_CTX_free(_ssl_server_ctx); + _ssl_server_ctx = NULL; + return SERVER_FAIL; + } + } + + if (!_server_thread) { + _server_thread = malloc(sizeof(pthread_t)); + + if (!_server_thread) { + log_log(MODULE_NAME, LOG_ERR, "Failed to allocate resources for server_thread"); + return SERVER_FAIL; + } + + thread_queue_init(&_server_thread_queue); + pthread_create(_server_thread, NULL, (void* (*)(void*))_server_thread_func, NULL); + + pthread_join(*_server_thread, NULL); + + free(_server_thread); + _server_thread = NULL; + } + + + return SERVER_OK; +} diff --git a/server.h b/server.h new file mode 100644 index 0000000..3590364 --- /dev/null +++ b/server.h @@ -0,0 +1,20 @@ +#ifndef _OAUTH2LIB_SERVER_ +#define _OAUTH2LIB_SERVER_ + +enum server_start_return { + SERVER_FAIL = 0, + SERVER_OK = 1, + SERVER_STARTED +}; + +enum server_state { + SERVER_FAILED, + SERVER_NOT_STARTED, + SERVER_STARTING, + SERVER_RUNNING, + SERVER_FINISHING, +}; + +enum server_start_return server_start(); + +#endif //_OAUTH2LIB_SERVER_ diff --git a/ssl.c b/ssl.c new file mode 100644 index 0000000..c3c93df --- /dev/null +++ b/ssl.c @@ -0,0 +1,183 @@ +#include "ssl.h" +#include "log.h" + +#include +#include + +#include +#include +#include +#include +#include + +#define LOG_NAME "SSL" + +static bool _initialized = false; +static SSL_CTX* _ssl_ctx = NULL; +static EVP_MD* _sha256_ctx = NULL; + +static const char _pkey_file_name[] = "key.pem"; +static const char _cert_file_name[] = "cert.pem"; +static const char _passphrase[] = "This is absolutely insecure."; + +static int _pem_passphrase_cb(char* buf, int size, int rwflag, void* userdata) { + int len = strlen(_passphrase); + memcpy(buf, _passphrase, len < size ? len : size); + return len; +} + +bool _initialize_ssl() { + + SSL_library_init(); + SSL_load_error_strings(); + + _ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (_ssl_ctx == NULL) { + fprintf(stderr, "Failed to initialize Openssl.\n"); + + return false; + } + + SSL_CTX_set_verify(_ssl_ctx, SSL_VERIFY_PEER, NULL); + + if (!SSL_CTX_set_default_verify_paths(_ssl_ctx)) { + fprintf(stderr, "Failed to set default paths for openssl.\n"); + + return false; + } + + if (!SSL_CTX_set_min_proto_version(_ssl_ctx, TLS1_3_VERSION)) { + fprintf(stderr, "Failed to set minimum TLS version.\n"); + + return false; + } + + return true; +} + +bool initialize_ssl() { + if (!_initialized) { + _initialized = _initialize_ssl(); + } + return _initialized; +} + +void ssl_close_ssl() { + + SSL_CTX_free(_ssl_ctx); + EVP_MD_free(_sha256_ctx); + +} + +SSL_CTX* ssl_get_ctx() { return _ssl_ctx; } +EVP_MD* ssl_get_sha256() { return _sha256_ctx; } + + +bool ssl_generate_cert() { + + EVP_PKEY* pkey = NULL; + X509* cert = NULL; + X509_NAME* cert_issuer; + bool success = false; + + pkey = EVP_EC_gen("P-256"); + + if (pkey == NULL) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to generate private key for certificate."); + goto generate_cert_cleanup; + } + + cert = X509_new(); + if (cert == NULL) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to generate certificate."); + goto generate_cert_cleanup; + } + + ASN1_INTEGER_set(X509_get_serialNumber(cert), 1); + X509_gmtime_adj(X509_get_notBefore(cert), 0); + X509_gmtime_adj(X509_get_notAfter(cert), 31536000L); + + cert_issuer = X509_get_subject_name(cert); + + if (cert_issuer == NULL || + !( X509_NAME_add_entry_by_txt(cert_issuer, "C", MBSTRING_ASC, (unsigned char*)"AT", -1, -1, 0) + && X509_NAME_add_entry_by_txt(cert_issuer, "O", MBSTRING_ASC, (unsigned char*)"me", -1, -1, 0) + && X509_NAME_add_entry_by_txt(cert_issuer, "CN", MBSTRING_ASC, (unsigned char*)"localhost", -1, -1, 0))) { + + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to create issuer name."); + goto generate_cert_cleanup; + + } + + if (!X509_set_issuer_name(cert, cert_issuer)) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to set issuer name on certificate."); + goto generate_cert_cleanup; + } + + if (!X509_set_pubkey(cert, pkey)) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to set private key on certificate"); + goto generate_cert_cleanup; + } + + if (!X509_sign(cert, pkey, EVP_sha256())) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to sign certificate."); + goto generate_cert_cleanup; + } + + FILE* key_file = NULL; + key_file = fopen(_pkey_file_name, "wb"); + + if (!PEM_write_PrivateKey(key_file, pkey, EVP_des_ede3_cbc(), (unsigned char*)_passphrase, sizeof(_passphrase), NULL, NULL)) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to write private certificate to disk."); + fclose(key_file); + goto generate_cert_cleanup; + } + fclose(key_file); + + key_file = fopen(_cert_file_name, "wb"); + if (!PEM_write_X509(key_file, cert)) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to write certificate to disk."); + fclose(key_file); + goto generate_cert_cleanup; + } + fclose(key_file); + + success = true; + +generate_cert_cleanup: + + EVP_PKEY_free(pkey); + X509_free(cert); + X509_NAME_free(cert_issuer); + + return success; +} + +bool ssl_use_ppcerts(SSL_CTX* ctx) { + + if (access(_cert_file_name, R_OK) != 0 && access(_pkey_file_name, R_OK) != 0) { + log_log(LOG_NAME, LOG_INFO, "Could not access certificates - Generating new ones"); + bool res = ssl_generate_cert(); + if (!res) { + log_log(LOG_NAME, LOG_ERR, "Failed to generate certificates"); + return false; + } + } + + if (SSL_CTX_use_certificate_chain_file(ctx, _cert_file_name) != 1) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to load SSL certificate."); + return false; + } + + SSL_CTX_set_default_passwd_cb(ctx, _pem_passphrase_cb); + + if (SSL_CTX_use_PrivateKey_file(ctx, _pkey_file_name, SSL_FILETYPE_PEM) != 1) { + log_log_ssl_err(LOG_NAME, LOG_ERR, "Failed to load SSL Private Key."); + return false; + } + + SSL_CTX_set_default_passwd_cb(ctx, NULL); + + return true; +} + diff --git a/ssl.h b/ssl.h new file mode 100644 index 0000000..535c762 --- /dev/null +++ b/ssl.h @@ -0,0 +1,17 @@ +#ifndef _OAUTH2LIB_SSL_ +#define _OAUTH2LIB_SSL_ + +#include + +#include + +bool initialize_ssl(); + +SSL_CTX* ssl_get_ctx(); +EVP_MD* ssl_get_sha256(); + +bool ssl_generate_cert(); + +bool ssl_use_ppcerts(SSL_CTX* ctx); + +#endif //_OAUTH2LIB_SSL_ diff --git a/util/a.out b/util/a.out new file mode 100755 index 0000000..04e9f5b Binary files /dev/null and b/util/a.out differ diff --git a/util/sorted_str_set.c b/util/sorted_str_set.c new file mode 100644 index 0000000..b6b0d08 --- /dev/null +++ b/util/sorted_str_set.c @@ -0,0 +1,174 @@ +#include "sorted_str_set.h" + +#include +#include +#include + +#define SORTED_STR_SET_INITIAL_CAPACITY 8 +#define SORTED_STR_SET_GROWTH_FACTOR 1.5 + +static ssize_t _get(const sorted_str_set* set, const char* key, int* out_res) { + ssize_t index = 0; + ssize_t low = 0; + ssize_t upp = set->size - 1; + int res = 0; + + while (low <= upp) { + index = low + (upp-low)/2; + + res = strcmp(key, set->data[index].key); + if (res == 0) { + if (out_res) *out_res = res; + return index; + } + + if (res < 0) { + upp = index - 1; + } else { + low = index + 1; + } + } + + if (res > 0) { + ++index; + } + + return -1; +} + + +sorted_str_set* sorted_str_set_new(size_t element_size) { + + sorted_str_set* set = malloc(sizeof(sorted_str_set)); + if (!set) { return NULL; } + + struct sorted_str_set_pair* data = malloc(SORTED_STR_SET_INITIAL_CAPACITY * sizeof(struct sorted_str_set_pair)); + if (!data) { + free(set); + return NULL; + } + + *set = (sorted_str_set) { + .element_size = element_size, + .size = 0, + .capacity = SORTED_STR_SET_INITIAL_CAPACITY, + .data = data, + }; + return set; +} + +void sorted_str_set_free(sorted_str_set* set) { + assert(set); //, "set must not be NULL."); + for (int i = 0; i < set->size; ++i) { + free(set->data[i].key); + free(set->data[i].element); + } + free(set->data); + free(set); +} + +void* sorted_str_set_get(const sorted_str_set* set, const char* key) { + + assert(set); //, "set must not be null"); + assert(key); //, "key must not be null"); + + if (set->size == 0) { + return NULL; + } + + ssize_t i = _get(set, key, NULL); + + if (i == -1) { + return NULL; + } else { + return set->data[i].element; + } +} + +void* sorted_str_set_insert(sorted_str_set* set, const char* key, const void* element) { + + assert(set); //, "set must not be NULL"); + assert(key); //, "key must not be NULL"); + assert(element); //, "element must not be NULL"); + + size_t index = 0; + + if (set->size != 0) { + ssize_t low = 0; + ssize_t upp = set->size - 1; + int res = 0; + + while (low <= upp) { + index = low + (upp-low)/2; + + res = strcmp(key, set->data[index].key); + if (res == 0) { + return set->data[index].element; + } + + if (res < 0) { + upp = index - 1; + } else { + low = index + 1; + } + } + + if (res > 0) { + ++index; + } + } + + if (set->size == set->capacity) { + size_t new_cap = set->capacity * SORTED_STR_SET_GROWTH_FACTOR; + struct sorted_str_set_pair* tmp = reallocarray(set->data, new_cap, sizeof(struct sorted_str_set_pair)); + if (!tmp) { + return NULL; + } + + set->data = tmp; + set->capacity = new_cap; + } + + struct sorted_str_set_pair pair = { + .key = malloc(strlen(key) + 1), + .element = malloc(set->element_size), + }; + if (!pair.key || !pair.element) { + free(pair.key); + free(pair.element); + return NULL; + } + strcpy(pair.key, key); + memcpy(pair.element, element, set->element_size); + + if (set->size - index) { + memmove(&set->data[index + 1], &set->data[index], (set->size - index) * sizeof(struct sorted_str_set_pair)); + } + + set->data[index] = pair; + set->size++; + + return set->data[index].element; +} + +void sorted_str_set_remove(sorted_str_set* set, const char* key) { + + assert(set);//, "set must not be NULL"); + assert(key);//, "key must not be NULL"); + + ssize_t index = _get(set, key, NULL); + if (index == -1) { + return; + } + + struct sorted_str_set_pair* loc = &set->data[index]; + + free(loc->key); + free(loc->element); + + set->size--; + + memmove(loc, loc+1, (set->size - index) * sizeof(struct sorted_str_set_pair)); + +} + diff --git a/util/sorted_str_set.h b/util/sorted_str_set.h new file mode 100644 index 0000000..23184f4 --- /dev/null +++ b/util/sorted_str_set.h @@ -0,0 +1,69 @@ +#ifndef _OAUTH2LIB_UTIL_SORTED_SET_ +#define _OAUTH2LIB_UTIL_SORTED_SET_ + +#include + +struct sorted_str_set_pair { + char* key; + void* element; +}; + +typedef struct { + size_t element_size; + size_t size; + size_t capacity; + + struct sorted_str_set_pair* data; +} sorted_str_set; + + +/** + * @name sorted_str_set_new + * @description Allocates and returns a new instance of sorted_str_set. Must be freed with @sorted_str_set_free + * + * @param element_size the size of each element (not the key) + * @return the newly allocated set or NULL if allocation fails + */ +sorted_str_set* sorted_str_set_new(size_t element_size); + +/** + * @name sorted_str_set_free + * @description Frees all resources of an instance of sorted_str_set. Has to be called before the end of the lifetime of the variable. + * Does not free any memory possibly pointed to by element. + * + * @param set the set to be deallocated. Must not be NULL. + */ +void sorted_str_set_free(sorted_str_set* set); + +/** + * @name sorted_str_set_free + * @description Returns a pointer to the element associated with key within set + * + * @param set the set to be accessed. Must not be NULL. + * @param key the key whose value should be retrieved. Must be null-terminated. + * @return a pointer to the element associated with key or NULL if key is not in the set + */ +void* sorted_str_set_get(const sorted_str_set* set, const char* key); + +/** + * @name sorted_str_set_insert + * @description Inserts a new element associated with key into set and returns a pointer to the element. + * If key is already present in set, nothing is modified and a pointer to the existing element is returned. + * + * @param set the set to be accessed. Must not be NULL. + * @param key the key for the element. Must not be NULL. Must be null-terminated. + * @param element the element to be inserted. A copy of element is made. Is expected to be the size of set->element_size. Must not be NULL. + * @return a pointer to the newly inserted element or if the key is already existing, the element associated with key. + */ +void* sorted_str_set_insert(sorted_str_set* set, const char* key, const void* element); + +/** + * @name sorted_str_set_remove + * @description Removes the key and element associated by key from set. If the key is not present, does nothing. + * + * @param set the set to be accessed. Must not be NULL. + * @aram key the key whose element should be removed. Must not be NULL. + */ +void sorted_str_set_remove(sorted_str_set* set, const char* key); + +#endif //_OAUTH2LIB_UTIL_SORTED_SET_ diff --git a/util/test_sorted_str_set.c b/util/test_sorted_str_set.c new file mode 100644 index 0000000..7cc973c --- /dev/null +++ b/util/test_sorted_str_set.c @@ -0,0 +1,67 @@ +#include "sorted_set.h" + +#include +#include + +#define REF(dtype, value) ((dtype []){ (value) }) + +void print_set(const sorted_str_set* set) { + for (int i = 0; i < set->size; ++i) { + printf("\"%s\": %d\n", set->data[i].key, *(int*)set->data[i].element); + } +} + +int main() { + + sorted_str_set *set = sorted_str_set_new(sizeof(int)); + + printf("%p\n", sorted_str_set_get(set, "zero")); + + printf("-%d\n", strcmp("-one", "one")); + + sorted_str_set_insert(set, "three", REF(int, 3)); + sorted_str_set_insert(set, "one", REF(int, 1)); + sorted_str_set_insert(set, "-one", REF(int, -1)); + sorted_str_set_insert(set, "four", REF(int, 4)); + sorted_str_set_insert(set, "two", REF(int, 2)); + + print_set(set); + + printf("\nsize: %lu\n", set->size); + printf("%d\n", *(int*)sorted_str_set_get(set, "one")); + printf("%d\n", *(int*)sorted_str_set_get(set, "two")); + printf("%d\n", *(int*)sorted_str_set_get(set, "three")); + printf("%d\n", *(int*)sorted_str_set_get(set, "four")); + printf("%d\n", *(int*)sorted_str_set_get(set, "-one")); + + printf("\n%d should be 2\n", *(int*)sorted_str_set_insert(set, "two", REF(int, -2))); + printf("size: %lu (should be the same as before)\n", set->size); + + sorted_str_set_insert(set, "ten", REF(int, 10)); + sorted_str_set_insert(set, "eleven", REF(int, 11)); + sorted_str_set_insert(set, "negative four", REF(int, -4)); + sorted_str_set_insert(set, "fifty", REF(int, 50)); + + printf("%d\n", *(int*)sorted_str_set_get(set, "ten")); + printf("%d\n", *(int*)sorted_str_set_get(set, "eleven")); + printf("%d\n", *(int*)sorted_str_set_get(set, "negative four")); + printf("%d\n", *(int*)sorted_str_set_get(set, "fifty")); + + printf("size: %lu\n", set->size); + + print_set(set); + + sorted_str_set_remove(set, "three"); + sorted_str_set_remove(set, "one"); + sorted_str_set_remove(set, "three"); + sorted_str_set_remove(set, "four"); + sorted_str_set_remove(set, "four"); + + printf("size: %lu\n", set->size); + + print_set(set); + + sorted_str_set_free(set); + + return 0; +} diff --git a/util/thread_queue.c b/util/thread_queue.c new file mode 100644 index 0000000..bf03db3 --- /dev/null +++ b/util/thread_queue.c @@ -0,0 +1,204 @@ +#include "thread_queue.h" + +#include +#include +#include +#include +#include + + +#define MSGPOOL_SIZE 256 + +struct msglist { + struct threadmsg msg; + struct msglist *next; +}; + +static inline struct msglist *get_msglist(struct threadqueue *queue) +{ + struct msglist *tmp; + + if(queue->msgpool != NULL) { + tmp = queue->msgpool; + queue->msgpool = tmp->next; + queue->msgpool_length--; + } else { + tmp = malloc(sizeof *tmp); + } + + return tmp; +} + +static inline void release_msglist(struct threadqueue *queue,struct msglist *node) +{ + + if(queue->msgpool_length > ( queue->length/8 + MSGPOOL_SIZE)) { + free(node); + } else { + node->msg.data = NULL; + node->msg.msgtype = 0; + node->next = queue->msgpool; + queue->msgpool = node; + queue->msgpool_length++; + } + if(queue->msgpool_length > (queue->length/4 + MSGPOOL_SIZE*10)) { + struct msglist *tmp = queue->msgpool; + queue->msgpool = tmp->next; + free(tmp); + queue->msgpool_length--; + } +} + +int thread_queue_init(struct threadqueue *queue) +{ + int ret = 0; + if (queue == NULL) { + return EINVAL; + } + memset(queue, 0, sizeof(struct threadqueue)); + ret = pthread_cond_init(&queue->cond, NULL); + if (ret != 0) { + return ret; + } + + ret = pthread_mutex_init(&queue->mutex, NULL); + if (ret != 0) { + pthread_cond_destroy(&queue->cond); + return ret; + } + + return 0; + +} + +int thread_queue_add(struct threadqueue *queue, void *data, long msgtype) +{ + struct msglist *newmsg; + pthread_mutex_lock(&queue->mutex); + newmsg = get_msglist(queue); + if (newmsg == NULL) { + pthread_mutex_unlock(&queue->mutex); + return ENOMEM; + } + newmsg->msg.data = data; + newmsg->msg.msgtype = msgtype; + + newmsg->next = NULL; + if (queue->last == NULL) { + queue->last = newmsg; + queue->first = newmsg; + } else { + queue->last->next = newmsg; + queue->last = newmsg; + } + + if(queue->length == 0) + pthread_cond_broadcast(&queue->cond); + queue->length++; + pthread_mutex_unlock(&queue->mutex); + + return 0; + +} + +int thread_queue_get(struct threadqueue *queue, const struct timespec *timeout, struct threadmsg *msg) +{ + struct msglist *firstrec; + int ret = 0; + struct timespec abstimeout; + + if (queue == NULL || msg == NULL) { + return EINVAL; + } + if (timeout) { + struct timeval now; + + gettimeofday(&now, NULL); + abstimeout.tv_sec = now.tv_sec + timeout->tv_sec; + abstimeout.tv_nsec = (now.tv_usec * 1000) + timeout->tv_nsec; + if (abstimeout.tv_nsec >= 1000000000) { + abstimeout.tv_sec++; + abstimeout.tv_nsec -= 1000000000; + } + } + + pthread_mutex_lock(&queue->mutex); + + /* Will wait until awakened by a signal or broadcast */ + while (queue->first == NULL && ret != ETIMEDOUT) { //Need to loop to handle spurious wakeups + if (timeout) { + ret = pthread_cond_timedwait(&queue->cond, &queue->mutex, &abstimeout); + } else { + pthread_cond_wait(&queue->cond, &queue->mutex); + + } + } + if (ret == ETIMEDOUT) { + pthread_mutex_unlock(&queue->mutex); + return ret; + } + + firstrec = queue->first; + queue->first = queue->first->next; + queue->length--; + + if (queue->first == NULL) { + queue->last = NULL; // we know this since we hold the lock + queue->length = 0; + } + + + msg->data = firstrec->msg.data; + msg->msgtype = firstrec->msg.msgtype; + msg->qlength = queue->length; + + release_msglist(queue,firstrec); + pthread_mutex_unlock(&queue->mutex); + + return 0; +} + +//maybe caller should supply a callback for cleaning the elements ? +int thread_queue_cleanup(struct threadqueue *queue, int freedata) +{ + struct msglist *rec; + struct msglist *next; + struct msglist *recs[2]; + int ret,i; + if (queue == NULL) { + return EINVAL; + } + + pthread_mutex_lock(&queue->mutex); + recs[0] = queue->first; + recs[1] = queue->msgpool; + for(i = 0; i < 2 ; i++) { + rec = recs[i]; + while (rec) { + next = rec->next; + if (freedata) { + free(rec->msg.data); + } + free(rec); + rec = next; + } + } + + pthread_mutex_unlock(&queue->mutex); + ret = pthread_mutex_destroy(&queue->mutex); + pthread_cond_destroy(&queue->cond); + + return ret; + +} + +long thread_queue_length(struct threadqueue *queue) +{ + long counter; + // get the length properly + pthread_mutex_lock(&queue->mutex); + counter = queue->length; + pthread_mutex_unlock(&queue->mutex); + return counter; + +} diff --git a/util/thread_queue.h b/util/thread_queue.h new file mode 100644 index 0000000..3e196ab --- /dev/null +++ b/util/thread_queue.h @@ -0,0 +1,185 @@ +#ifndef _OAUTH2LIB_UTIL_THREAD_QUEUE_ +#define _OAUTH2LIB_UTIL_THREAD_QUEUE_ + +// https://stackoverflow.com/questions/4577961/pthread-synchronized-blocking-queue + +#include + +/** + * @defgroup ThreadQueue ThreadQueue + * + * Little API for waitable queues, typically used for passing messages + * between threads. + * + */ + +/** + * @mainpage + */ + +/** + * A thread message. + * + * @ingroup ThreadQueue + * + * This is used for passing to #thread_queue_get for retreive messages. + * the date is stored in the data member, the message type in the #msgtype. + * + * Typical: + * @code + * struct threadmsg; + * struct myfoo *foo; + * while(1) + * ret = thread_queue_get(&queue,NULL,&message); + * .. + * foo = msg.data; + * switch(msg.msgtype){ + * ... + * } + * } + * @endcode + * + */ +struct threadmsg{ + /** + * Holds the data. + */ + void *data; + /** + * Holds the messagetype + */ + long msgtype; + /** + * Holds the current queue lenght. Might not be meaningful if there's several readers + */ + long qlength; + +}; + + +/** + * A TthreadQueue + * + * @ingroup ThreadQueue + * + * You should threat this struct as opaque, never ever set/get any + * of the variables. You have been warned. + */ +struct threadqueue { +/** + * Length of the queue, never set this, never read this. + * Use #threadqueue_length to read it. + */ + long length; +/** + * Mutex for the queue, never touch. + */ + pthread_mutex_t mutex; +/** + * Condition variable for the queue, never touch. + */ + pthread_cond_t cond; +/** + * Internal pointers for the queue, never touch. + */ + struct msglist *first,*last; +/** + * Internal cache of msglists + */ + struct msglist *msgpool; +/** + * No. of elements in the msgpool + */ + long msgpool_length; +}; + +/** + * Initializes a queue. + * + * @ingroup ThreadQueue + * + * thread_queue_init initializes a new threadqueue. A new queue must always + * be initialized before it is used. + * + * @param queue Pointer to the queue that should be initialized + * @return 0 on success see pthread_mutex_init + */ +int thread_queue_init(struct threadqueue *queue); + +/** + * Adds a message to a queue + * + * @ingroup ThreadQueue + * + * thread_queue_add adds a "message" to the specified queue, a message + * is just a pointer to a anything of the users choice. Nothing is copied + * so the user must keep track on (de)allocation of the data. + * A message type is also specified, it is not used for anything else than + * given back when a message is retreived from the queue. + * + * @param queue Pointer to the queue on where the message should be added. + * @param data the "message". + * @param msgtype a long specifying the message type, choice of the user. + * @return 0 on succes ENOMEM if out of memory EINVAL if queue is NULL + */ +int thread_queue_add(struct threadqueue *queue, void *data, long msgtype); + +/** + * Gets a message from a queue + * + * @ingroup ThreadQueue + * + * thread_queue_get gets a message from the specified queue, it will block + * the caling thread untill a message arrives, or the (optional) timeout occurs. + * If timeout is NULL, there will be no timeout, and thread_queue_get will wait + * untill a message arrives. + * + * struct timespec is defined as: + * @code + * struct timespec { + * long tv_sec; // seconds + * long tv_nsec; // nanoseconds + * }; + * @endcode + * + * @param queue Pointer to the queue to wait on for a message. + * @param timeout timeout on how long to wait on a message + * @param msg pointer that is filled in with mesagetype and data + * + * @return 0 on success EINVAL if queue is NULL ETIMEDOUT if timeout occurs + */ +int thread_queue_get(struct threadqueue *queue, const struct timespec *timeout, struct threadmsg *msg); + + +/** + * Gets the length of a queue + * + * @ingroup ThreadQueue + * + * threadqueue_length returns the number of messages waiting in the queue + * + * @param queue Pointer to the queue for which to get the length + * @return the length(number of pending messages) in the queue + */ +long thread_queue_length( struct threadqueue *queue ); + +/** + * @ingroup ThreadQueue + * Cleans up the queue. + * + * threadqueue_cleanup cleans up and destroys the queue. + * This will remove all messages from a queue, and reset it. If + * freedata is != 0 free(3) will be called on all pending messages in the queue + * You cannot call this if there are someone currently adding or getting messages + * from the queue. + * After a queue have been cleaned, it cannot be used again untill #thread_queue_init + * has been called on the queue. + * + * @param queue Pointer to the queue that should be cleaned + * @param freedata set to nonzero if free(3) should be called on remaining + * messages + * @return 0 on success EINVAL if queue is NULL EBUSY if someone is holding any locks on the queue + */ +int thread_queue_cleanup(struct threadqueue *queue, int freedata); + +#endif //_OAUTH2LIB_UTIL_THREAD_QUEUE_