OAuth2Lib/pkce.c

146 lines
3.7 KiB
C

#include "pkce.h"

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#include <openssl/sha.h>

#include "base64.h"
#include "log.h"
#include "ssl.h"
// Log tag for this module. NOTE(review): the log calls in this file pass the
// literal "PKCE" rather than this macro.
#define MODULE_NAME "PKCE"
// Number of random bytes drawn for a code_verifier, before base64url encoding.
// 32 bytes encode to 43 unpadded base64url characters.
#define CODE_VERIFIER_LENGTH 32 // the length of the binary random string before base64url encoding
/*
 * A generated PKCE code_verifier. `key` is a heap-allocated, NUL-terminated
 * base64url string produced by generate_code_verifier() and released with
 * free_code_verifier(). The value { 0, NULL } signals that generation failed.
 * Fields are initialized positionally ({ length, key }) elsewhere in this
 * file, so keep their order.
 */
struct code_verifier_s {
size_t length; // characters in key, not counting the trailing '\0'
char* key;     // NUL-terminated verifier text; NULL on failure
};
/*
 * Release the string owned by a code_verifier_s.
 *
 * The verifier is a secret, so its bytes are wiped with OPENSSL_cleanse
 * (which the optimizer may not elide, unlike memset) before free(). The
 * struct is then reset to { 0, NULL }, so calling this twice, or on a
 * failed/empty verifier, is harmless.
 */
static inline void free_code_verifier(struct code_verifier_s* ptr) {
    if (ptr->key) {
        OPENSSL_cleanse(ptr->key, ptr->length);
    }
    free(ptr->key);
    ptr->key = NULL;
    ptr->length = 0;
}
/*
 * Generate a PKCE code_verifier.
 *
 * Draws CODE_VERIFIER_LENGTH bytes from OpenSSL's private CSPRNG
 * (RAND_priv_bytes) and base64url-encodes them. Trailing '=' padding and
 * line terminators are then stripped from the encoder's output.
 *
 * Returns { length, key }. `key` is a heap-allocated NUL-terminated string
 * owned by the caller (release with free_code_verifier()). On failure it
 * returns { 0, NULL } and logs at LOG_DEBUG.
 */
static struct code_verifier_s generate_code_verifier(void) {
    uint8_t code_buffer[CODE_VERIFIER_LENGTH];
    if (RAND_priv_bytes(code_buffer, sizeof(code_buffer)) != 1) {
        log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to generate code_verifier.\n");
        // The buffer may have been partially filled; don't leave it behind.
        OPENSSL_cleanse(code_buffer, sizeof(code_buffer));
        return (struct code_verifier_s) { 0, NULL };
    }

    size_t encoded_size = 0;
    char* encoded = base64url_encode(code_buffer, sizeof(code_buffer), &encoded_size);
    // The raw bytes are the secret itself; wipe them from the stack whether
    // or not encoding succeeded.
    OPENSSL_cleanse(code_buffer, sizeof(code_buffer));
    if (!encoded) {
        log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to encode code_verifier.\n");
        return (struct code_verifier_s) { 0, NULL };
    }

    // Strip only what is actually padding or line terminators. A fixed
    // "-= 2" would chop real characters if the encoder emits no newline or
    // no '=', and would underflow on short output.
    while (encoded_size > 0 &&
           (encoded[encoded_size - 1] == '=' ||
            encoded[encoded_size - 1] == '\n' ||
            encoded[encoded_size - 1] == '\r')) {
        encoded_size--;
    }
    encoded[encoded_size] = '\0';
    return (struct code_verifier_s) { encoded_size, encoded };
}
/*
 * Compute the SHA-256 digest of `str_len` bytes starting at `str`.
 *
 * The digest algorithm comes from ssl_get_sha256(). Returns a heap buffer of
 * EVP_MD_get_size(<that digest>) bytes holding the raw binary digest (not
 * NUL-terminated); the caller must free() it. Returns NULL on any failure.
 * Failures are logged at LOG_DEBUG, except a NULL from ssl_get_sha256().
 */
static uint8_t* sha256_encode(const char* str, size_t str_len) {
    uint8_t* hash_buf = NULL;
    EVP_MD* sha256_ctx = NULL;
    int digest_size = 0;
    unsigned int len = 0;

    EVP_MD_CTX* hash_ctx = EVP_MD_CTX_new();
    if (!hash_ctx) {
        log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to create Crypto context.\n");
        return NULL;
    }
    sha256_ctx = ssl_get_sha256();
    if (!sha256_ctx) {
        goto cleanup;
    }
    // EVP_MD_get_size reports errors as a non-positive value; guard before
    // using it as an allocation size.
    digest_size = EVP_MD_get_size(sha256_ctx);
    if (digest_size <= 0) {
        log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to query sha256 digest size.\n");
        goto cleanup;
    }
    if (!EVP_DigestInit_ex(hash_ctx, sha256_ctx, NULL)) {
        log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to initialize sha256 digest.\n");
        goto cleanup;
    }
    if (!EVP_DigestUpdate(hash_ctx, str, str_len)) {
        log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to set message to digest.\n");
        goto cleanup;
    }
    hash_buf = malloc((size_t)digest_size);
    if (!hash_buf) {
        log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to allocate memory for digest output.\n");
        goto cleanup;
    }
    if (!EVP_DigestFinal_ex(hash_ctx, hash_buf, &len)) {
        log_log_ssl_err("PKCE", LOG_DEBUG, "Failed to compute hash.\n");
        // Previously leaked: release the output buffer on this path too.
        free(hash_buf);
        hash_buf = NULL;
        goto cleanup;
    }

cleanup:
    EVP_MD_CTX_free(hash_ctx);
    return hash_buf;
}
/*
 * Build the raw HTTP/1.1 request for the OAuth 2.0 authorization endpoint:
 *
 *   GET <server_url>?response_type=code&client_id=<client_id>&state=<state>
 *       &redirect_uri=<redirect_uri> HTTP/1.1\r\nHost: localhost\r\n\r\n
 *
 * Arguments are inserted verbatim; they are NOT percent-encoded.
 *
 * Returns a heap-allocated NUL-terminated string the caller must free().
 * Returns NULL if server_url, client_id, redirect_uri or state is NULL, or
 * if formatting or allocation fails.
 *
 * NOTE(review): `scope` is accepted but never placed in the request, and
 * the Host header is hard-coded to "localhost". Confirm both are intended.
 */
char* _construct_get(char* server_url, char* client_id, char* redirect_uri, char* scope, char* state) {
    static const char msg_template[] = "GET %s?response_type=code&client_id=%s&state=%s&redirect_uri=%s HTTP/1.1\r\nHost: localhost\r\n\r\n";
    (void)scope; // unused by the template, see NOTE above

    if (!server_url || !client_id || !redirect_uri || !state) {
        return NULL;
    }

    // Measure the exact formatted length instead of hand-counting "%s"
    // placeholders, so the buffer can never be mis-sized.
    int needed = snprintf(NULL, 0, msg_template, server_url, client_id, state, redirect_uri);
    if (needed < 0) {
        return NULL;
    }
    size_t len = (size_t)needed + 1;
    char* buf = malloc(len);
    if (!buf) {
        return NULL;
    }
    // Anything other than the measured length means an error or truncation.
    if (snprintf(buf, len, msg_template, server_url, client_id, state, redirect_uri) != needed) {
        free(buf);
        return NULL;
    }
    return buf;
}
/*
 * Run the client-side PKCE setup (RFC 7636, S256 method):
 *   1. generate a random code_verifier,
 *   2. compute code_challenge = BASE64URL(SHA256(code_verifier)),
 *   3. log both at LOG_DEBUG.
 *
 * Any failure is logged at LOG_ERR and aborts the remaining steps. Every
 * intermediate buffer is released on all paths.
 */
void do_pkce() {
    // Declared and NULL-initialized up front so the shared cleanup label can
    // free them safely from any exit path.
    uint8_t* code_challenge_sha256 = NULL;
    char* code_challenge = NULL;
    size_t code_challenge_length = 0;
    int digest_size = 0;

    log_log("PKCE", LOG_INFO, "generating code verifier");
    struct code_verifier_s code_verifier = generate_code_verifier();
    if (!code_verifier.key) {
        log_log("PKCE", LOG_ERR, "Failed to generate code verifier");
        goto pkce_cleanup;
    }
    // NOTE(review): the verifier is a secret; logging it, even at DEBUG,
    // exposes it to anyone who can read the logs.
    log_log("PKCE", LOG_DEBUG, "generated verifer: %s\n", code_verifier.key);

    log_log("PKCE", LOG_INFO, "generating code challenge");
    code_challenge_sha256 = sha256_encode(code_verifier.key, code_verifier.length);
    if (!code_challenge_sha256) {
        log_log("PKCE", LOG_ERR, "Failed to encode code challenge");
        goto pkce_cleanup;
    }

    digest_size = EVP_MD_get_size(ssl_get_sha256());
    if (digest_size <= 0) {
        log_log("PKCE", LOG_ERR, "Failed to query sha256 digest size");
        goto pkce_cleanup;
    }
    code_challenge = base64url_encode(code_challenge_sha256, digest_size, &code_challenge_length);
    if (!code_challenge) {
        log_log("PKCE", LOG_ERR, "Failed to base64url-encode code challenge");
        goto pkce_cleanup;
    }
    // Strip only real trailing padding or line terminators, rather than
    // blindly dropping two characters.
    while (code_challenge_length > 0 &&
           (code_challenge[code_challenge_length - 1] == '=' ||
            code_challenge[code_challenge_length - 1] == '\n' ||
            code_challenge[code_challenge_length - 1] == '\r')) {
        code_challenge_length--;
    }
    code_challenge[code_challenge_length] = '\0';
    log_log("PKCE", LOG_DEBUG, "generated code challenge: %s\n", code_challenge);

pkce_cleanup:
    free_code_verifier(&code_verifier);
    free(code_challenge_sha256);
    free(code_challenge);
}