From 88d25a37b48cb41f9e4559aa69e55bb5be0b0191 Mon Sep 17 00:00:00 2001 From: Victor Hiairrassary Date: Sun, 15 Dec 2024 23:18:37 +0100 Subject: [PATCH] Support prepared statements and parameters --- CMakeLists.txt | 8 +- src/http_handler/authentication.cpp | 55 ++++ src/http_handler/bindings.cpp | 69 ++++ src/http_handler/handler.cpp | 220 +++++++++++++ src/http_handler/response_serializer.cpp | 142 ++++++++ src/httpserver_extension.cpp | 309 +----------------- src/include/httpserver_extension.hpp | 6 +- .../httpserver_extension/http_handler.hpp | 11 + .../http_handler/authentication.hpp | 10 + .../http_handler/bindings.hpp | 15 + .../http_handler/common.hpp | 24 ++ .../http_handler/handler.hpp | 40 +++ .../http_handler/response_serializer.hpp | 11 + src/include/httpserver_extension/state.hpp | 28 ++ test/sql/auth.test | 129 ++++++++ test/sql/basics.test | 69 ++++ test/sql/quack.test | 23 -- test/sql/simple-get.test | 52 +++ 18 files changed, 895 insertions(+), 326 deletions(-) create mode 100644 src/http_handler/authentication.cpp create mode 100644 src/http_handler/bindings.cpp create mode 100644 src/http_handler/handler.cpp create mode 100644 src/http_handler/response_serializer.cpp create mode 100644 src/include/httpserver_extension/http_handler.hpp create mode 100644 src/include/httpserver_extension/http_handler/authentication.hpp create mode 100644 src/include/httpserver_extension/http_handler/bindings.hpp create mode 100644 src/include/httpserver_extension/http_handler/common.hpp create mode 100644 src/include/httpserver_extension/http_handler/handler.hpp create mode 100644 src/include/httpserver_extension/http_handler/response_serializer.hpp create mode 100644 src/include/httpserver_extension/state.hpp create mode 100644 test/sql/auth.test create mode 100644 test/sql/basics.test delete mode 100644 test/sql/quack.test create mode 100644 test/sql/simple-get.test diff --git a/CMakeLists.txt b/CMakeLists.txt index 435c3dc..4687238 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,12 +16,16 @@ include_directories( # Embed ./src/assets/index.html as a C++ header add_custom_command( OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp - COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp playgroundContent + COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/httpserver_extension/http_handler/playground.hpp playgroundContent DEPENDS ${PROJECT_SOURCE_DIR}/src/assets/index.html ) set(EXTENSION_SOURCES src/httpserver_extension.cpp + src/http_handler/authentication.cpp + src/http_handler/bindings.cpp + src/http_handler/handler.cpp + src/http_handler/response_serializer.cpp ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp ) @@ -37,7 +41,9 @@ build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES}) include_directories(${OPENSSL_INCLUDE_DIR}) target_link_libraries(${LOADABLE_EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES}) +set_property(TARGET ${LOADABLE_EXTENSION_NAME} PROPERTY CXX_STANDARD 17) target_link_libraries(${EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES}) +set_property(TARGET ${EXTENSION_NAME} PROPERTY CXX_STANDARD 17) if(MINGW) set(WIN_LIBS crypt32 ws2_32 wsock32) diff --git a/src/http_handler/authentication.cpp b/src/http_handler/authentication.cpp new file mode 100644 index 0000000..4003bd3 --- /dev/null +++ b/src/http_handler/authentication.cpp @@ -0,0 +1,55 @@ +#include "httpserver_extension/http_handler/common.hpp" +#include "httpserver_extension/state.hpp" +#include +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +namespace duckdb_httpserver { + +// Base64 decoding function +static std::string base64_decode(const std::string &in) { + std::string out; + std::vector T(256, -1); + for (int i = 0; i < 64; i++) + T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i; + + int val = 0, valb = -8; + for (unsigned char c : in) { + if (T[c] == -1) break; + val = (val << 6) + T[c]; + valb += 6; + if (valb >= 0) { + out.push_back(char((val >> valb) & 0xFF)); + valb -= 8; + } + } + return out; +} + +// Check authentication +void CheckAuthentication(const duckdb_httplib_openssl::Request& req) { + if (global_state.auth_token.empty()) { + return; // No authentication required if no token is set + } + + // Check for X-API-Key header + auto api_key = req.get_header_value("X-API-Key"); + if (!api_key.empty() && api_key == global_state.auth_token) { + return; + } + + // Check for Basic Auth + auto auth = req.get_header_value("Authorization"); + if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) { + std::string decoded_auth = base64_decode(auth.substr(6)); + if (decoded_auth == global_state.auth_token) { + return; + } + } + + throw HttpHandlerException(401, "Unauthorized"); +} + +} // namespace duckdb_httpserver diff --git a/src/http_handler/bindings.cpp b/src/http_handler/bindings.cpp new file mode 100644 index 0000000..43953e9 --- /dev/null +++ b/src/http_handler/bindings.cpp @@ -0,0 +1,69 @@ +#include "httpserver_extension/http_handler/common.hpp" +#include "duckdb.hpp" +#include "yyjson.hpp" +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +using namespace duckdb; +using namespace duckdb_yyjson; + +namespace duckdb_httpserver { + +static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) { + if (!yyjson_is_obj(parameterVal)) { + throw HttpHandlerException(400, "The parameter `" + key + "` must be an object"); + } + + auto typeVal = yyjson_obj_get(parameterVal, "type"); + if (!typeVal) { + throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field"); + } + if (!yyjson_is_str(typeVal)) { + throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string"); + } + auto type = std::string(yyjson_get_str(typeVal)); + + auto valueVal = yyjson_obj_get(parameterVal, "value"); + if (!valueVal) { + throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field"); + } + + if (type == "TEXT") { + if (!yyjson_is_str(valueVal)) { + throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string"); + } + + return BoundParameterData(Value(yyjson_get_str(valueVal))); + } + else if (type == "BOOLEAN") { + if (!yyjson_is_bool(valueVal)) { + throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean"); + } + + return BoundParameterData(Value(bool(yyjson_get_bool(valueVal)))); + } + + throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`"); +} + +case_insensitive_map_t ExtractQueryParameters(yyjson_val* parametersVal) { + if (!parametersVal || !yyjson_is_obj(parametersVal)) { + throw HttpHandlerException(400, "The `parameters` field must be an object"); + } + + case_insensitive_map_t named_values; + + size_t idx, max; + yyjson_val *parameterKeyVal, *parameterVal; + yyjson_obj_foreach(parametersVal, idx, max, parameterKeyVal, parameterVal) { + auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal)); + + named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal); + } + + return named_values; +} + +} // namespace duckdb_httpserver diff --git a/src/http_handler/handler.cpp b/src/http_handler/handler.cpp new file mode 100644 index 0000000..4b2ce83 --- /dev/null +++ b/src/http_handler/handler.cpp @@ -0,0 +1,220 @@ +#include "httpserver_extension/http_handler/authentication.hpp" +#include "httpserver_extension/http_handler/bindings.hpp" +#include "httpserver_extension/http_handler/common.hpp" +#include "httpserver_extension/http_handler/handler.hpp" +#include "httpserver_extension/http_handler/playground.hpp" +#include "httpserver_extension/http_handler/response_serializer.hpp" +#include "httpserver_extension/state.hpp" +#include "duckdb.hpp" +#include "yyjson.hpp" + +#include +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +using namespace duckdb; +using namespace duckdb_yyjson; + +namespace duckdb_httpserver { + +// Handle both GET and POST requests +void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) { + try { + // CORS allow + res.set_header("Access-Control-Allow-Origin", "*"); + res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Max-Age", "86400"); + + // Handle preflight OPTIONS request + if (req.method == "OPTIONS") { + res.status = 204; // No content + return; + } + + CheckAuthentication(req); + + auto queryApiParameters = ExtractQueryApiParameters(req); + + if (!queryApiParameters.sqlQueryOpt.has_value()) { + res.status = 200; + res.set_content(reinterpret_cast(playgroundContent), sizeof(playgroundContent), "text/html"); + return; + } + + if (!global_state.db_instance) { + throw IOException("Database instance not initialized"); + } + + auto start = std::chrono::system_clock::now(); + auto result = ExecuteQuery(req, queryApiParameters); + auto end = std::chrono::system_clock::now(); + auto elapsed = std::chrono::duration_cast(end - start); + + QueryExecStats stats { + static_cast(elapsed.count()) / 1000, + 0, + 0 + }; + + // Format output + if (queryApiParameters.outputFormat == OutputFormat::Ndjson) { + std::string json_output = ConvertResultToNDJSON(*result); + res.set_content(json_output, "application/x-ndjson"); + } + else { + auto json_output = ConvertResultToJSON(*result, stats); + res.set_content(json_output, "application/json"); + } + } + catch (const HttpHandlerException& ex) { + res.status = ex.status; + res.set_content(ex.message, "text/plain"); + } + catch (const std::exception& ex) { + res.status = 500; + std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what()); + res.set_content(error_message, "text/plain"); + } +} + +// Execute query (optionally using a prepared statement) +std::unique_ptr ExecuteQuery( + const duckdb_httplib_openssl::Request& req, + const QueryApiParameters& queryApiParameters +) { + Connection con(*global_state.db_instance); + std::unique_ptr result; + auto query = queryApiParameters.sqlQueryOpt.value(); + + auto use_prepared_stmt = + queryApiParameters.sqlParametersOpt.has_value() && + queryApiParameters.sqlParametersOpt.value().empty() == false; + + if (use_prepared_stmt) { + auto prepared_stmt = con.Prepare(query); + if (prepared_stmt->HasError()) { + throw HttpHandlerException(500, prepared_stmt->GetError()); + } + + auto named_values = queryApiParameters.sqlParametersOpt.value(); + + auto prepared_stmt_result = prepared_stmt->Execute(named_values); + D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT); + result = unique_ptr_cast(std::move(prepared_stmt_result))->Materialize(); + } else { + result = con.Query(query); + } + + if (result->HasError()) { + throw HttpHandlerException(500, result->GetError()); + } + + return result; +} + +QueryApiParameters ExtractQueryApiParameters(const duckdb_httplib_openssl::Request& req) { + if (req.method == "POST" && req.has_header("Content-Type") && req.get_header_value("Content-Type") == "application/json") { + return ExtractQueryApiParametersComplex(req); + } + else { + return QueryApiParameters { + ExtractSqlQuerySimple(req), + std::nullopt, + ExtractOutputFormatSimple(req), + }; + } +} + +std::optional ExtractSqlQuerySimple(const duckdb_httplib_openssl::Request& req) { + // Check if the query is in the URL parameters + if (req.has_param("query")) { + return req.get_param_value("query"); + } + else if (req.has_param("q")) { + return req.get_param_value("q"); + } + + // If not in URL, and it's a POST request, check the body + else if (req.method == "POST" && !req.body.empty()) { + return req.body; + } + + return std::nullopt; +} + +OutputFormat ExtractOutputFormatSimple(const duckdb_httplib_openssl::Request& req) { + // Check for format in URL parameter or header + if (req.has_param("default_format")) { + return ParseOutputFormat(req.get_param_value("default_format")); + } + else if (req.has_header("X-ClickHouse-Format")) { + return ParseOutputFormat(req.get_header_value("X-ClickHouse-Format")); + } + else if (req.has_header("format")) { + return ParseOutputFormat(req.get_header_value("format")); + } + else { + return OutputFormat::Ndjson; + } +} + +OutputFormat ParseOutputFormat(const std::string& formatStr) { + if (formatStr == "JSONEachRow" || formatStr == "ndjson" || formatStr == "jsonl") { + return OutputFormat::Ndjson; + } + else if (formatStr == "JSONCompact") { + return OutputFormat::Json; + } + else { + throw HttpHandlerException(400, "Unknown format"); + } +} + +QueryApiParameters ExtractQueryApiParametersComplex(const duckdb_httplib_openssl::Request& req) { + yyjson_doc *bodyDoc = nullptr; + + try { + auto bodyJson = req.body; + auto bodyJsonCStr = bodyJson.c_str(); + bodyDoc = yyjson_read(bodyJsonCStr, strlen(bodyJsonCStr), 0); + + return ExtractQueryApiParametersComplexImpl(bodyDoc); + } + catch (const std::exception& exception) { + yyjson_doc_free(bodyDoc); + throw; + } +} + +QueryApiParameters ExtractQueryApiParametersComplexImpl(yyjson_doc* bodyDoc) { + if (!bodyDoc) { + throw HttpHandlerException(400, "Unable to parse the request body"); + } + + auto bodyRoot = yyjson_doc_get_root(bodyDoc); + if (!yyjson_is_obj(bodyRoot)) { + throw HttpHandlerException(400, "The request body must be an object"); + } + + auto queryVal = yyjson_obj_get(bodyRoot, "query"); + if (!queryVal || !yyjson_is_str(queryVal)) { + throw HttpHandlerException(400, "The `query` field must be a string"); + } + + auto formatVal = yyjson_obj_get(bodyRoot, "format"); + if (!formatVal || !yyjson_is_str(formatVal)) { + throw HttpHandlerException(400, "The `format` field must be a string"); + } + + return QueryApiParameters { + std::string(yyjson_get_str(queryVal)), + ExtractQueryParameters(yyjson_obj_get(bodyRoot, "parameters")), + ParseOutputFormat(std::string(yyjson_get_str(formatVal))), + }; +} + +} // namespace duckdb_httpserver diff --git a/src/http_handler/response_serializer.cpp b/src/http_handler/response_serializer.cpp new file mode 100644 index 0000000..d418cf8 --- /dev/null +++ b/src/http_handler/response_serializer.cpp @@ -0,0 +1,142 @@ +#include "httpserver_extension/http_handler/common.hpp" +#include "duckdb.hpp" + +#include "yyjson.hpp" +#include + +using namespace duckdb; +using namespace duckdb_yyjson; + +namespace duckdb_httpserver { + +static std::string GetColumnType(MaterializedQueryResult& result, idx_t column) { + if (result.RowCount() == 0) { + return "String"; + } + switch (result.types[column].id()) { + case LogicalTypeId::FLOAT: + return "Float"; + case LogicalTypeId::DOUBLE: + return "Double"; + case LogicalTypeId::INTEGER: + return "Int32"; + case LogicalTypeId::BIGINT: + return "Int64"; + case LogicalTypeId::UINTEGER: + return "UInt32"; + case LogicalTypeId::UBIGINT: + return "UInt64"; + case LogicalTypeId::VARCHAR: + return "String"; + case LogicalTypeId::TIME: + return "DateTime"; + case LogicalTypeId::DATE: + return "Date"; + case LogicalTypeId::TIMESTAMP: + return "DateTime"; + case LogicalTypeId::BOOLEAN: + return "Int8"; + default: + return "String"; + } + return "String"; +} + +// Convert the query result to JSON format +std::string ConvertResultToJSON(MaterializedQueryResult& result, QueryExecStats& stats) { + auto doc = yyjson_mut_doc_new(nullptr); + auto root = yyjson_mut_obj(doc); + yyjson_mut_doc_set_root(doc, root); + + // Add meta information + auto meta_array = yyjson_mut_arr(doc); + for (idx_t col = 0; col < result.ColumnCount(); ++col) { + auto column_obj = yyjson_mut_obj(doc); + yyjson_mut_obj_add_str(doc, column_obj, "name", result.ColumnName(col).c_str()); + yyjson_mut_arr_append(meta_array, column_obj); + std::string tp(GetColumnType(result, col)); + yyjson_mut_obj_add_strcpy(doc, column_obj, "type", tp.c_str()); + } + yyjson_mut_obj_add_val(doc, root, "meta", meta_array); + + // Add data + auto data_array = yyjson_mut_arr(doc); + for (idx_t row = 0; row < result.RowCount(); ++row) { + auto row_array = yyjson_mut_arr(doc); + for (idx_t col = 0; col < result.ColumnCount(); ++col) { + Value value = result.GetValue(col, row); + if (value.IsNull()) { + yyjson_mut_arr_append(row_array, yyjson_mut_null(doc)); + } else { + std::string value_str = value.ToString(); + yyjson_mut_arr_append(row_array, yyjson_mut_strncpy(doc, value_str.c_str(), value_str.length())); + } + } + yyjson_mut_arr_append(data_array, row_array); + } + yyjson_mut_obj_add_val(doc, root, "data", data_array); + + // Add row count + yyjson_mut_obj_add_int(doc, root, "rows", result.RowCount()); + + //"statistics":{"elapsed":0.00031403,"rows_read":1,"bytes_read":0}} + auto stat_obj = yyjson_mut_obj_add_obj(doc, root, "statistics"); + yyjson_mut_obj_add_real(doc, stat_obj, "elapsed", stats.elapsed_sec); + yyjson_mut_obj_add_int(doc, stat_obj, "rows_read", stats.read_rows); + yyjson_mut_obj_add_int(doc, stat_obj, "bytes_read", stats.read_bytes); + + // Write to string + auto data = yyjson_mut_write(doc, 0, nullptr); + if (!data) { + yyjson_mut_doc_free(doc); + throw InternalException("Failed to render the result as JSON, yyjson failed"); + } + + std::string json_output(data); + free(data); + yyjson_mut_doc_free(doc); + return json_output; +} + +// Convert the query result to NDJSON (JSONEachRow) format +std::string ConvertResultToNDJSON(MaterializedQueryResult& result) { + std::string ndjson_output; + + for (idx_t row = 0; row < result.RowCount(); ++row) { + // Create a new JSON document for each row + auto doc = yyjson_mut_doc_new(nullptr); + auto root = yyjson_mut_obj(doc); + yyjson_mut_doc_set_root(doc, root); + + for (idx_t col = 0; col < result.ColumnCount(); ++col) { + Value value = result.GetValue(col, row); + const char* column_name = result.ColumnName(col).c_str(); + + // Handle null values and add them to the JSON object + if (value.IsNull()) { + yyjson_mut_obj_add_null(doc, root, column_name); + } else { + // Convert value to string and add it to the JSON object + std::string value_str = value.ToString(); + yyjson_mut_obj_add_strncpy(doc, root, column_name, value_str.c_str(), value_str.length()); + } + } + + char *json_line = yyjson_mut_write(doc, 0, nullptr); + if (!json_line) { + yyjson_mut_doc_free(doc); + throw InternalException("Failed to render a row as JSON, yyjson failed"); + } + + ndjson_output += json_line; + ndjson_output += "\n"; + + // Free allocated memory for this row + free(json_line); + yyjson_mut_doc_free(doc); + } + + return ndjson_output; +} + +} // namespace duckdb_httpserver diff --git a/src/httpserver_extension.cpp b/src/httpserver_extension.cpp index 65a9aa0..f43859d 100644 --- a/src/httpserver_extension.cpp +++ b/src/httpserver_extension.cpp @@ -1,5 +1,8 @@ #define DUCKDB_EXTENSION_MAIN #include "httpserver_extension.hpp" +#include "httpserver_extension/http_handler.hpp" +#include "httpserver_extension/state.hpp" + #include "duckdb.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" @@ -17,303 +20,13 @@ #include #endif -#define CPPHTTPLIB_OPENSSL_SUPPORT -#include "httplib.hpp" -#include "yyjson.hpp" - -#include "playground.hpp" - -using namespace duckdb_yyjson; // NOLINT - -namespace duckdb { - -struct HttpServerState { - std::unique_ptr server; - std::unique_ptr server_thread; - std::atomic is_running; - DatabaseInstance* db_instance; - unique_ptr allocator; - std::string auth_token; - - HttpServerState() : is_running(false), db_instance(nullptr) {} -}; - -static HttpServerState global_state; - -std::string GetColumnType(MaterializedQueryResult &result, idx_t column) { - if (result.RowCount() == 0) { - return "String"; - } - switch (result.types[column].id()) { - case LogicalTypeId::FLOAT: - return "Float"; - case LogicalTypeId::DOUBLE: - return "Double"; - case LogicalTypeId::INTEGER: - return "Int32"; - case LogicalTypeId::BIGINT: - return "Int64"; - case LogicalTypeId::UINTEGER: - return "UInt32"; - case LogicalTypeId::UBIGINT: - return "UInt64"; - case LogicalTypeId::VARCHAR: - return "String"; - case LogicalTypeId::TIME: - return "DateTime"; - case LogicalTypeId::DATE: - return "Date"; - case LogicalTypeId::TIMESTAMP: - return "DateTime"; - case LogicalTypeId::BOOLEAN: - return "Int8"; - default: - return "String"; - } - return "String"; -} - -struct ReqStats { - float elapsed_sec; - int64_t read_bytes; - int64_t read_rows; -}; - -// Convert the query result to JSON format -static std::string ConvertResultToJSON(MaterializedQueryResult &result, ReqStats &req_stats) { - auto doc = yyjson_mut_doc_new(nullptr); - auto root = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, root); - // Add meta information - auto meta_array = yyjson_mut_arr(doc); - for (idx_t col = 0; col < result.ColumnCount(); ++col) { - auto column_obj = yyjson_mut_obj(doc); - yyjson_mut_obj_add_str(doc, column_obj, "name", result.ColumnName(col).c_str()); - yyjson_mut_arr_append(meta_array, column_obj); - std::string tp(GetColumnType(result, col)); - yyjson_mut_obj_add_strcpy(doc, column_obj, "type", tp.c_str()); - } - yyjson_mut_obj_add_val(doc, root, "meta", meta_array); - - // Add data - auto data_array = yyjson_mut_arr(doc); - for (idx_t row = 0; row < result.RowCount(); ++row) { - auto row_array = yyjson_mut_arr(doc); - for (idx_t col = 0; col < result.ColumnCount(); ++col) { - Value value = result.GetValue(col, row); - if (value.IsNull()) { - yyjson_mut_arr_append(row_array, yyjson_mut_null(doc)); - } else { - std::string value_str = value.ToString(); - yyjson_mut_arr_append(row_array, yyjson_mut_strncpy(doc, value_str.c_str(), value_str.length())); - } - } - yyjson_mut_arr_append(data_array, row_array); - } - yyjson_mut_obj_add_val(doc, root, "data", data_array); - - // Add row count - yyjson_mut_obj_add_int(doc, root, "rows", result.RowCount()); - //"statistics":{"elapsed":0.00031403,"rows_read":1,"bytes_read":0}} - auto stat_obj = yyjson_mut_obj_add_obj(doc, root, "statistics"); - yyjson_mut_obj_add_real(doc, stat_obj, "elapsed", req_stats.elapsed_sec); - yyjson_mut_obj_add_int(doc, stat_obj, "rows_read", req_stats.read_rows); - yyjson_mut_obj_add_int(doc, stat_obj, "bytes_read", req_stats.read_bytes); - // Write to string - auto data = yyjson_mut_write(doc, 0, nullptr); - if (!data) { - yyjson_mut_doc_free(doc); - throw InternalException("Failed to render the result as JSON, yyjson failed"); - } - - std::string json_output(data); - free(data); - yyjson_mut_doc_free(doc); - return json_output; -} - -// New: Base64 decoding function -std::string base64_decode(const std::string &in) { - std::string out; - std::vector T(256, -1); - for (int i = 0; i < 64; i++) - T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i; - - int val = 0, valb = -8; - for (unsigned char c : in) { - if (T[c] == -1) break; - val = (val << 6) + T[c]; - valb += 6; - if (valb >= 0) { - out.push_back(char((val >> valb) & 0xFF)); - valb -= 8; - } - } - return out; -} - -// Auth Check -bool IsAuthenticated(const duckdb_httplib_openssl::Request& req) { - if (global_state.auth_token.empty()) { - return true; // No authentication required if no token is set - } - - // Check for X-API-Key header - auto api_key = req.get_header_value("X-API-Key"); - if (!api_key.empty() && api_key == global_state.auth_token) { - return true; - } - - // Check for Basic Auth - auto auth = req.get_header_value("Authorization"); - if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) { - std::string decoded_auth = base64_decode(auth.substr(6)); - if (decoded_auth == global_state.auth_token) { - return true; - } - } - - return false; -} - -// Convert the query result to NDJSON (JSONEachRow) format -static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) { - std::string ndjson_output; - - for (idx_t row = 0; row < result.RowCount(); ++row) { - // Create a new JSON document for each row - auto doc = yyjson_mut_doc_new(nullptr); - auto root = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, root); - - for (idx_t col = 0; col < result.ColumnCount(); ++col) { - Value value = result.GetValue(col, row); - const char* column_name = result.ColumnName(col).c_str(); - - // Handle null values and add them to the JSON object - if (value.IsNull()) { - yyjson_mut_obj_add_null(doc, root, column_name); - } else { - // Convert value to string and add it to the JSON object - std::string value_str = value.ToString(); - yyjson_mut_obj_add_strncpy(doc, root, column_name, value_str.c_str(), value_str.length()); - } - } - - char *json_line = yyjson_mut_write(doc, 0, nullptr); - if (!json_line) { - yyjson_mut_doc_free(doc); - throw InternalException("Failed to render a row as JSON, yyjson failed"); - } - - ndjson_output += json_line; - ndjson_output += "\n"; - - // Free allocated memory for this row - free(json_line); - yyjson_mut_doc_free(doc); - } - - return ndjson_output; +namespace duckdb_httpserver { + duckdb_httpserver::State global_state; } -// Handle both GET and POST requests -void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) { - std::string query; - - // Check authentication - if (!IsAuthenticated(req)) { - res.status = 401; - res.set_content("Unauthorized", "text/plain"); - return; - } - - // CORS allow - res.set_header("Access-Control-Allow-Origin", "*"); - res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT"); - res.set_header("Access-Control-Allow-Headers", "*"); - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Max-Age", "86400"); - - // Handle preflight OPTIONS request - if (req.method == "OPTIONS") { - res.status = 204; // No content - return; - } - - // Check if the query is in the URL parameters - if (req.has_param("query")) { - query = req.get_param_value("query"); - } - else if (req.has_param("q")) { - query = req.get_param_value("q"); - } - // If not in URL, and it's a POST request, check the body - else if (req.method == "POST" && !req.body.empty()) { - query = req.body; - } - // If no query found, return an error - else { - res.status = 200; - res.set_content(reinterpret_cast(playgroundContent), "text/html"); - return; - } - - // Set default format to JSONCompact - std::string format = "JSONEachRow"; - - // Check for format in URL parameter or header - if (req.has_param("default_format")) { - format = req.get_param_value("default_format"); - } else if (req.has_header("X-ClickHouse-Format")) { - format = req.get_header_value("X-ClickHouse-Format"); - } else if (req.has_header("format")) { - format = req.get_header_value("format"); - } - - try { - if (!global_state.db_instance) { - throw IOException("Database instance not initialized"); - } - - Connection con(*global_state.db_instance); - auto start = std::chrono::system_clock::now(); - auto result = con.Query(query); - auto end = std::chrono::system_clock::now(); - auto elapsed = std::chrono::duration_cast(end - start); +using namespace duckdb_httpserver; - if (result->HasError()) { - res.status = 500; - res.set_content(result->GetError(), "text/plain"); - return; - } - - - ReqStats stats{ - static_cast(elapsed.count()) / 1000, - 0, - 0 - }; - - // Format Options - if (format == "JSONEachRow") { - std::string json_output = ConvertResultToNDJSON(*result); - res.set_content(json_output, "application/x-ndjson"); - } else if (format == "JSONCompact") { - std::string json_output = ConvertResultToJSON(*result, stats); - res.set_content(json_output, "application/json"); - } else { - // Default to NDJSON for DuckDB's own queries - std::string json_output = ConvertResultToNDJSON(*result); - res.set_content(json_output, "application/x-ndjson"); - } - - } catch (const Exception& ex) { - res.status = 500; - std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what()); - res.set_content(error_message, "text/plain"); - } -} +namespace duckdb { void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port, string_t auth = string_t()) { if (global_state.is_running) { @@ -325,9 +38,9 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port, string_t global_state.is_running = true; global_state.auth_token = auth.GetString(); - // Custom basepath, defaults to root / + // Custom basepath, defaults to root / const char* base_path_env = std::getenv("DUCKDB_HTTPSERVER_BASEPATH"); - std::string base_path = "/"; + std::string base_path = "/"; if (base_path_env && base_path_env[0] == '/' && strlen(base_path_env) > 1) { base_path = std::string(base_path_env); @@ -349,8 +62,8 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port, string_t global_state.allocator = make_uniq(); // Handle GET and POST requests - global_state.server->Get(base_path, HandleHttpRequest); - global_state.server->Post(base_path, HandleHttpRequest); + global_state.server->Get(base_path, HttpHandler); + global_state.server->Post(base_path, HttpHandler); // Health check endpoint global_state.server->Get("/ping", [](const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) { diff --git a/src/include/httpserver_extension.hpp b/src/include/httpserver_extension.hpp index 432d1c0..ac7af74 100644 --- a/src/include/httpserver_extension.hpp +++ b/src/include/httpserver_extension.hpp @@ -5,15 +5,13 @@ namespace duckdb { -class HttpserverExtension : public Extension { +struct HttpserverExtension: public Extension { public: void Load(DuckDB &db) override; std::string Name() override; - std::string Version() const override; + std::string Version() const override; }; -// Static server state declarations -struct HttpServerState; void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port); void HttpServerStop(); diff --git a/src/include/httpserver_extension/http_handler.hpp b/src/include/httpserver_extension/http_handler.hpp new file mode 100644 index 0000000..c06346b --- /dev/null +++ b/src/include/httpserver_extension/http_handler.hpp @@ -0,0 +1,11 @@ +#pragma once + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +namespace duckdb_httpserver { + +// HTTP handler +void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res); + +} // namespace duckdb_httpserver diff --git a/src/include/httpserver_extension/http_handler/authentication.hpp b/src/include/httpserver_extension/http_handler/authentication.hpp new file mode 100644 index 0000000..f41a487 --- /dev/null +++ b/src/include/httpserver_extension/http_handler/authentication.hpp @@ -0,0 +1,10 @@ +#pragma once + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +namespace duckdb_httpserver { + +void CheckAuthentication(const duckdb_httplib_openssl::Request& req); + +} // namespace duckdb_httpserver diff --git a/src/include/httpserver_extension/http_handler/bindings.hpp b/src/include/httpserver_extension/http_handler/bindings.hpp new file mode 100644 index 0000000..42eca57 --- /dev/null +++ b/src/include/httpserver_extension/http_handler/bindings.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "duckdb.hpp" +#include "yyjson.hpp" + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +namespace duckdb_httpserver { + +duckdb::case_insensitive_map_t ExtractQueryParameters( + duckdb_yyjson::yyjson_val* parametersVal +); + +} // namespace duckdb_httpserver diff --git a/src/include/httpserver_extension/http_handler/common.hpp b/src/include/httpserver_extension/http_handler/common.hpp new file mode 100644 index 0000000..2068362 --- /dev/null +++ b/src/include/httpserver_extension/http_handler/common.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include + +namespace duckdb_httpserver { + +// Used to have an easy to read control flow +struct HttpHandlerException: public std::exception { + int status; + std::string message; + + HttpHandlerException(int status, const std::string& message) : message(message), status(status) {} +}; + +// Statistics associated to the SQL query execution +struct QueryExecStats { + float elapsed_sec; + int64_t read_bytes; + int64_t read_rows; +}; + +} // namespace duckdb_httpserver diff --git a/src/include/httpserver_extension/http_handler/handler.hpp b/src/include/httpserver_extension/http_handler/handler.hpp new file mode 100644 index 0000000..d501d84 --- /dev/null +++ b/src/include/httpserver_extension/http_handler/handler.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include "duckdb.hpp" +#include "yyjson.hpp" +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +namespace duckdb_httpserver { + +enum class OutputFormat { + Ndjson, + Json, +}; + +struct QueryApiParameters { + std::optional sqlQueryOpt; + std::optional> sqlParametersOpt; + OutputFormat outputFormat; +}; + +std::unique_ptr ExecuteQuery( + const duckdb_httplib_openssl::Request& req, + const QueryApiParameters& queryApiParameters +); + +QueryApiParameters ExtractQueryApiParameters(const duckdb_httplib_openssl::Request& req); + +QueryApiParameters ExtractQueryApiParametersComplex(const duckdb_httplib_openssl::Request& req); + +QueryApiParameters ExtractQueryApiParametersComplexImpl(duckdb_yyjson::yyjson_doc* bodyDoc); + +std::optional ExtractSqlQuerySimple(const duckdb_httplib_openssl::Request& req); + +OutputFormat ExtractOutputFormatSimple(const duckdb_httplib_openssl::Request& req); + +OutputFormat ParseOutputFormat(const std::string& formatStr); + +} // namespace duckdb_httpserver diff --git a/src/include/httpserver_extension/http_handler/response_serializer.hpp b/src/include/httpserver_extension/http_handler/response_serializer.hpp new file mode 100644 index 0000000..66f21b0 --- /dev/null +++ b/src/include/httpserver_extension/http_handler/response_serializer.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "httpserver_extension/http_handler/common.hpp" +#include "duckdb.hpp" + +namespace duckdb_httpserver { + +std::string ConvertResultToJSON(duckdb::MaterializedQueryResult& result, QueryExecStats& stats); +std::string ConvertResultToNDJSON(duckdb::MaterializedQueryResult& result); + +} // namespace duckdb_httpserver diff --git a/src/include/httpserver_extension/state.hpp b/src/include/httpserver_extension/state.hpp new file mode 100644 index 0000000..2f0b70a --- /dev/null +++ b/src/include/httpserver_extension/state.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "duckdb.hpp" +#include "duckdb/common/allocator.hpp" +#include +#include +#include +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +namespace duckdb_httpserver { + +struct State { + std::unique_ptr server; + std::unique_ptr server_thread; + std::atomic is_running; + duckdb::DatabaseInstance* db_instance; + std::unique_ptr allocator; + std::string auth_token; + + State() : is_running(false), db_instance(nullptr) {} +}; + +extern State global_state; + +} // namespace duckdb_httpserver diff --git a/test/sql/auth.test b/test/sql/auth.test new file mode 100644 index 0000000..96d25f6 --- /dev/null +++ b/test/sql/auth.test @@ -0,0 +1,129 @@ +# name: test/sql/auth.test +# description: test httpserver extension +# group: [httpserver] + +################################################################ +# Setup +################################################################ + +require httpserver + +statement ok +INSTALL http_client FROM community; + +statement ok +LOAD http_client; + +statement ok +INSTALL json; + +statement ok +LOAD json; + +################################################################ +# No auth test +################################################################ + +query I +SELECT httpserve_start('127.0.0.1', 4000, ''); +---- +HTTP server started on 127.0.0.1:4000 + +query TTT +WITH response AS (SELECT http_get('http://127.0.0.1:4000/?q=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n + +query I +SELECT httpserve_stop(); +---- +HTTP server stopped + +################################################################ +# Basic auth test +################################################################ + +query I +SELECT httpserve_start('127.0.0.1', 4000, 'bob:pwd'); +---- +HTTP server started on 127.0.0.1:4000 + +query TTT +WITH response AS (SELECT http_get('http://127.0.0.1:4000/?q=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +401 Unauthorized Unauthorized + +query TTT +WITH response AS ( + SELECT http_get( + 'http://127.0.0.1:4000/?q=SELECT 123', + MAP { + 'Authorization': CONCAT('Basic ', TO_BASE64('bob:pwd'::BLOB)), + }, + MAP {} + ) response +) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n + +query I +SELECT httpserve_stop(); +---- +HTTP server stopped + +################################################################ +# Token test +################################################################ + +query I +SELECT httpserve_start('127.0.0.1', 4000, 'my-api-key'); +---- +HTTP server started on 127.0.0.1:4000 + +query TTT +WITH response AS (SELECT http_get('http://127.0.0.1:4000/?q=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +401 Unauthorized Unauthorized + +query TTT +WITH response AS ( + SELECT http_get( + 'http://127.0.0.1:4000/?q=SELECT 123', + MAP { + 'X-API-Key': 'my-api-key', + }, + MAP {} + ) response +) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n + +query I +SELECT httpserve_stop(); +---- +HTTP server stopped diff --git a/test/sql/basics.test b/test/sql/basics.test new file mode 100644 index 0000000..649504b --- /dev/null +++ b/test/sql/basics.test @@ -0,0 +1,69 @@ +# name: test/sql/basics.test +# description: test httpserver extension +# group: [httpserver] + +# Before we load the extension, this will fail +statement error +SELECT httpserve_start('127.0.0.1', 4000, ''); +---- +Catalog Error: Scalar Function with name httpserve_start does not exist! + +# Require statement will ensure this test is run with this extension loaded +require httpserver + +statement ok +INSTALL http_client FROM community; + +statement ok +LOAD http_client; + +statement ok +INSTALL json; + +statement ok +LOAD json; + +# The HTTP server is not available yet +query TTT +WITH response AS (SELECT http_get('http://127.0.0.1:4000/abc', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +-1 HTTP GET request failed. Connection error. (empty) + +# Start the HTTP server +query I +SELECT httpserve_start('127.0.0.1', 4000, ''); +---- +HTTP server started on 127.0.0.1:4000 + +# Simple request +query TTT +WITH response AS (SELECT http_get('http://127.0.0.1:4000/?q=SELECT ''World'' AS Hello', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"Hello":"World"}\n + +# Stop the HTTP server +query I +SELECT httpserve_stop(); +---- +HTTP server stopped + +# The HTTP server is not available anymore +query TTT +WITH response AS (SELECT http_get('http://127.0.0.1:4000/?q=SELECT ''World'' AS Hello', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +-1 HTTP GET request failed. Connection error. (empty) diff --git a/test/sql/quack.test b/test/sql/quack.test deleted file mode 100644 index 519a354..0000000 --- a/test/sql/quack.test +++ /dev/null @@ -1,23 +0,0 @@ -# name: test/sql/quack.test -# description: test quack extension -# group: [quack] - -# Before we load the extension, this will fail -statement error -SELECT quack('Sam'); ----- -Catalog Error: Scalar Function with name quack does not exist! - -# Require statement will ensure this test is run with this extension loaded -require quack - -# Confirm the extension works -query I -SELECT quack('Sam'); ----- -Quack Sam 🐥 - -query I -SELECT quack_openssl_version('Michael') ILIKE 'Quack Michael, my linked OpenSSL version is OpenSSL%'; ----- -true diff --git a/test/sql/simple-get.test b/test/sql/simple-get.test new file mode 100644 index 0000000..5c75fa5 --- /dev/null +++ b/test/sql/simple-get.test @@ -0,0 +1,52 @@ +# name: test/sql/simple-get.test +# description: test httpserver extension +# group: [httpserver] + +################################################################ +# Setup +################################################################ + +require httpserver + +statement ok +INSTALL http_client FROM community; + +statement ok +LOAD http_client; + +statement ok +INSTALL json; + +statement ok +LOAD json; + +query I +SELECT httpserve_start('127.0.0.1', 4000, ''); +---- +HTTP server started on 127.0.0.1:4000 + +################################################################ +# Tests +################################################################ + +# SQL request in `q` parameter +query TTT +WITH response AS (SELECT http_get('http://127.0.0.1:4000/?q=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n + +# SQL request in `query` parameter +query TTT +WITH response AS (SELECT http_get('http://127.0.0.1:4000/?query=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n