From ac69eaf402259a6cf74865c583a632703f7cfcff Mon Sep 17 00:00:00 2001 From: Alexander Busse Date: Mon, 26 Feb 2024 15:02:31 +0100 Subject: [PATCH] Handle https subscriptions, fix cert store bug https subscriptions were not handled, create a SSLCLient and reuse existing code for subscribing. the certificate store was shared between all SslClients in our RestClient, this does not work, because the client owns the store and deletes it on its own destruction. this patch also adds the certificate specified by OPENCMW_REST_CERT_FILE to our RestDefaultCLientCertificates to make it trusted. --- src/client/include/RestClientNative.hpp | 101 +++++----- .../include/RestDefaultClientCertificates.hpp | 9 + src/client/test/RestClient_tests.cpp | 173 +++++++++++++++++- 3 files changed, 239 insertions(+), 44 deletions(-) diff --git a/src/client/include/RestClientNative.hpp b/src/client/include/RestClientNative.hpp index 12524223..2ec18293 100644 --- a/src/client/include/RestClientNative.hpp +++ b/src/client/include/RestClientNative.hpp @@ -142,7 +142,6 @@ class RestClient : public ClientBase { std::mutex _subscriptionLock; std::map, httplib::Client> _subscription1; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - X509_STORE *_client_cert_store = nullptr; std::map, httplib::SSLClient> _subscription2; #endif @@ -166,12 +165,6 @@ class RestClient : public ClientBase { , _maxIoThreads(detail::find_argument_value([] { return MaxIoThreads(); }, initArgs...)) , _thread_pool(detail::find_argument_value([this] { return std::make_shared>(_name, _minIoThreads, _maxIoThreads); }, initArgs...)) , _caCertificate(detail::find_argument_value([] { return rest::DefaultCertificate().get(); }, initArgs...)) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (_client_cert_store != nullptr) { - X509_STORE_free(_client_cert_store); - } - _client_cert_store = detail::createCertificateStore(_caCertificate); -#endif } ~RestClient() override { RestClient::stop(); }; @@ -285,7 +278,8 @@ class RestClient : public ClientBase { if (cmd.topic.scheme() && equal_with_case_ignore(cmd.topic.scheme().value(), "https")) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT httplib::SSLClient client(cmd.topic.hostName().value(), cmd.topic.port() ? cmd.topic.port().value() : 443); - client.set_ca_cert_store(_client_cert_store); + // client owns its certificate store and destroys it after use. create a store for each client + client.set_ca_cert_store(detail::createCertificateStore(_caCertificate)); client.enable_server_certificate_verification(CHECK_CERTIFICATES); callback(client); #else @@ -315,45 +309,68 @@ class RestClient : public ClientBase { || equal_with_case_ignore(*cmd.topic.scheme(), "https") #endif ) { - auto it = _subscription1.find(cmd.topic); - if (it == _subscription1.end()) { - auto &client = _subscription1.try_emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value())).first->second; - client.set_follow_location(true); - - auto longPollingEndpoint = [&] { - if (!cmd.topic.queryParamMap().contains(LONG_POLLING_IDX_TAG)) { - return URI<>::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build(); - } else { - return URI<>::factory(cmd.topic).build(); - } - }(); - - const auto pollHeaders = getPreferredContentTypeHeader(longPollingEndpoint); - auto endpoint = longPollingEndpoint.relativeRef().value(); - client.set_read_timeout(cmd.timeout); // default keep-alive value - while (_run) { - auto redirect_get = [&client](auto url, auto headers) { - for (;;) { - auto result = client.Get(url, headers); - if (!result) return result; - - if (result->status >= 300 && result->status < 400) { - url = httplib::detail::decode_url(result.value().get_header_value("location"), true); - } else { - return result; - } + auto createNewSubscription = [&](auto &client) { + { + client.set_follow_location(true); + + auto longPollingEndpoint = [&] { + if (!cmd.topic.queryParamMap().contains(LONG_POLLING_IDX_TAG)) { + return URI<>::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build(); + } else { + return URI<>::factory(cmd.topic).build(); } - }; - if (const httplib::Result &result = redirect_get(endpoint, pollHeaders)) { - returnMdpMessage(cmd, result); - } else { // failed or server is down -> wait until retry - std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry - if (_run) { - returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast(result.error()))); + }(); + + const auto pollHeaders = getPreferredContentTypeHeader(longPollingEndpoint); + auto endpoint = longPollingEndpoint.relativeRef().value(); + client.set_read_timeout(cmd.timeout); // default keep-alive value + while (_run) { + auto redirect_get = [&client](auto url, auto headers) { + for (;;) { + auto result = client.Get(url, headers); + if (!result) return result; + + if (result->status >= 300 && result->status < 400) { + url = httplib::detail::decode_url(result.value().get_header_value("location"), true); + } else { + return result; + } + } + }; + if (const httplib::Result &result = redirect_get(endpoint, pollHeaders)) { + returnMdpMessage(cmd, result); + } else { // failed or server is down -> wait until retry + std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry + if (_run) { + returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast(result.error()))); + } } } } + }; + if (equal_with_case_ignore(*cmd.topic.scheme(), "http")) { + auto it = _subscription1.find(cmd.topic); + if (it == _subscription1.end()) { + _subscription1.emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value())); + createNewSubscription(_subscription1.at(cmd.topic)); + } + } else { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (auto it = _subscription2.find(cmd.topic); it == _subscription2.end()) { + _subscription2.emplace( + std::piecewise_construct, + std::forward_as_tuple(cmd.topic), + std::forward_as_tuple(cmd.topic.hostName().value(), cmd.topic.port().value())); + auto &client = _subscription2.at(cmd.topic); + client.set_ca_cert_store(detail::createCertificateStore(_caCertificate)); + client.enable_server_certificate_verification(CHECK_CERTIFICATES); + createNewSubscription(_subscription2.at(cmd.topic)); + } +#else + throw std::invalid_argument("https is not supported"); +#endif } + } else { throw std::invalid_argument(fmt::format("unsupported scheme '{}' for requested subscription '{}'", cmd.topic.scheme(), cmd.topic.str())); } diff --git a/src/client/include/RestDefaultClientCertificates.hpp b/src/client/include/RestDefaultClientCertificates.hpp index 7232150f..61783625 100644 --- a/src/client/include/RestDefaultClientCertificates.hpp +++ b/src/client/include/RestDefaultClientCertificates.hpp @@ -18,6 +18,15 @@ class DefaultCertificate { _concatenated_certificates += root_certificates[1]; _concatenated_certificates += root_certificates[2]; _concatenated_certificates += root_certificates[3]; + + if (auto filename = std::getenv("OPENCMW_REST_CERT_FILE"); filename) { + std::ifstream ifs{ filename }; + if (!ifs.is_open()) { + std::string cert; + ifs >> cert; + _concatenated_certificates += cert; + } + } } constexpr std::string get() const noexcept { return _concatenated_certificates; diff --git a/src/client/test/RestClient_tests.cpp b/src/client/test/RestClient_tests.cpp index 0f30d46c..24561aeb 100644 --- a/src/client/test/RestClient_tests.cpp +++ b/src/client/test/RestClient_tests.cpp @@ -73,8 +73,6 @@ TEST_CASE("Basic Rest Client Constructor and API Tests", "[Client]") { RestClient client5("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate)); REQUIRE(client5.defaultMimeType() == MIME::HTML); REQUIRE(client5.threadPool()->poolName() == "clientName"); - - REQUIRE_THROWS_AS(RestClient(ClientCertificates("Invalid Certificate Format")), std::invalid_argument); } TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") { @@ -123,6 +121,73 @@ TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") { } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT +TEST_CASE("Multiple Rest Client Get/Set Test - HTTPS", "[Client]") { + using namespace opencmw::client; + RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); + REQUIRE(RestClient::CHECK_CERTIFICATES); + RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check + REQUIRE(client.name() == "TestSSLClient"); + REQUIRE(client.defaultMimeType() == MIME::JSON); + + // HTTP + X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); + EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); + if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { + FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + httplib::SSLServer server(cert, pkey); + + std::string acceptHeader; + server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) { + if (req.headers.contains("accept")) { + acceptHeader = req.headers.find("accept")->second; + } else { + FAIL("no accept headers found"); + } + res.set_content("Hello World!", acceptHeader); + }); + client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); + while (!server.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + REQUIRE(server.is_running()); + + std::array, 4> dones; + dones[0] = false; + dones[1] = false; + dones[2] = false; + dones[3] = false; + std::atomic counter{ 0 }; + auto makeCommand = [&]() { + IoBuffer data; + data.put('A'); + data.put('B'); + data.put('C'); + data.put(0); + + Command command; + command.command = mdp::Command::Get; + command.topic = URI("https://localhost:8080/endPoint"); + command.data = std::move(data); + command.callback = [&dones, &counter](const mdp::Message &/*rep*/) { + int currentCounter = counter.fetch_add(1, std::memory_order_relaxed); + dones[currentCounter].store(true, std::memory_order_release); + // Assuming you have access to 'done' variable, uncomment the following line + dones[currentCounter].notify_all(); + }; + client.request(command); + }; + for (int i = 0; i < 4; i++) + makeCommand(); + + for (auto &done : dones) { + done.wait(false); + } + REQUIRE(std::ranges::all_of(dones, [](auto &done) { return done.load(std::memory_order_acquire); })); + REQUIRE(acceptHeader == MIME::JSON.typeName()); + server.stop(); +} + TEST_CASE("Basic Rest Client Get/Set Test - HTTPS", "[Client]") { using namespace opencmw::client; RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); @@ -296,4 +361,108 @@ TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test", "[Client]") { std::cout << "server stopped" << std::endl; } +TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test HTTPS", "[Client]") { + // HTTP + X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); + EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); + if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { + FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + using namespace opencmw::client; + + std::atomic updateCounter{ 0 }; + detail::EventDispatcher eventDispatcher; + httplib::SSLServer server(cert, pkey); + server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) { + DEBUG_LOG("Server received request"); + auto acceptType = req.headers.find("accept"); + if (acceptType == req.headers.end() || MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response +#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) + res.set_content(fmt::format("update counter = {}", updateCounter.load()), MIME::TEXT); +#else + res.set_content(fmt::format("update counter = {}", updateCounter.load()), std::string(MIME::TEXT.typeName())); +#endif + return; + } else { + fmt::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body); +#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) + res.set_chunked_content_provider(MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { +#else + res.set_chunked_content_provider(std::string(MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { +#endif + eventDispatcher.wait_event(sink); + return true; + }); + } + }); + server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) { + fmt::print("server received request on path '{}' body = '{}'\n", req.path, req.body); + res.set_content("Hello World!", "text/plain"); + }); + + RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); + + client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); + while (!server.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + REQUIRE(server.is_running()); + REQUIRE(RestClient::CHECK_CERTIFICATES); + RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check + REQUIRE(client.name() == "TestSSLClient"); + REQUIRE(client.defaultMimeType() == MIME::JSON); + + std::atomic receivedRegular(0); + std::atomic receivedError(0); + IoBuffer data; + data.put('A'); + data.put('B'); + data.put('C'); + data.put(0); + + Command command; + command.command = mdp::Command::Subscribe; + command.topic = URI("https://localhost:8080/event"); + command.data = std::move(data); + command.callback = [&receivedRegular, &receivedError](const mdp::Message &rep) { + fmt::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size()); + if (rep.error.size() == 0) { + receivedRegular.fetch_add(1, std::memory_order_relaxed); + } else { + receivedError.fetch_add(1, std::memory_order_relaxed); + } + receivedRegular.notify_all(); + receivedError.notify_all(); + }; + + client.request(command); + + std::cout << "client request launched" << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + eventDispatcher.send_event("test-event meta data"); + std::jthread dispatcher([&updateCounter, &eventDispatcher] { + while (updateCounter < 5) { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); + } + }); + dispatcher.join(); + + while (receivedRegular.load(std::memory_order_relaxed) < 5) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::cout << "done waiting" << std::endl; + REQUIRE(receivedRegular.load(std::memory_order_acquire) >= 5); + + command.command = mdp::Command::Unsubscribe; + client.request(command); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::cout << "done Unsubscribe" << std::endl; + + client.stop(); + server.stop(); + eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); + std::cout << "server stopped" << std::endl; +} + } // namespace opencmw::rest_client_test