diff --git a/src/client/include/RestClientNative.hpp b/src/client/include/RestClientNative.hpp index 12524223..2ec18293 100644 --- a/src/client/include/RestClientNative.hpp +++ b/src/client/include/RestClientNative.hpp @@ -142,7 +142,6 @@ class RestClient : public ClientBase { std::mutex _subscriptionLock; std::map, httplib::Client> _subscription1; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - X509_STORE *_client_cert_store = nullptr; std::map, httplib::SSLClient> _subscription2; #endif @@ -166,12 +165,6 @@ class RestClient : public ClientBase { , _maxIoThreads(detail::find_argument_value([] { return MaxIoThreads(); }, initArgs...)) , _thread_pool(detail::find_argument_value([this] { return std::make_shared>(_name, _minIoThreads, _maxIoThreads); }, initArgs...)) , _caCertificate(detail::find_argument_value([] { return rest::DefaultCertificate().get(); }, initArgs...)) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (_client_cert_store != nullptr) { - X509_STORE_free(_client_cert_store); - } - _client_cert_store = detail::createCertificateStore(_caCertificate); -#endif } ~RestClient() override { RestClient::stop(); }; @@ -285,7 +278,8 @@ class RestClient : public ClientBase { if (cmd.topic.scheme() && equal_with_case_ignore(cmd.topic.scheme().value(), "https")) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT httplib::SSLClient client(cmd.topic.hostName().value(), cmd.topic.port() ? cmd.topic.port().value() : 443); - client.set_ca_cert_store(_client_cert_store); + // client owns its certificate store and destroys it after use. create a store for each client + client.set_ca_cert_store(detail::createCertificateStore(_caCertificate)); client.enable_server_certificate_verification(CHECK_CERTIFICATES); callback(client); #else @@ -315,45 +309,68 @@ class RestClient : public ClientBase { || equal_with_case_ignore(*cmd.topic.scheme(), "https") #endif ) { - auto it = _subscription1.find(cmd.topic); - if (it == _subscription1.end()) { - auto &client = _subscription1.try_emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value())).first->second; - client.set_follow_location(true); - - auto longPollingEndpoint = [&] { - if (!cmd.topic.queryParamMap().contains(LONG_POLLING_IDX_TAG)) { - return URI<>::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build(); - } else { - return URI<>::factory(cmd.topic).build(); - } - }(); - - const auto pollHeaders = getPreferredContentTypeHeader(longPollingEndpoint); - auto endpoint = longPollingEndpoint.relativeRef().value(); - client.set_read_timeout(cmd.timeout); // default keep-alive value - while (_run) { - auto redirect_get = [&client](auto url, auto headers) { - for (;;) { - auto result = client.Get(url, headers); - if (!result) return result; - - if (result->status >= 300 && result->status < 400) { - url = httplib::detail::decode_url(result.value().get_header_value("location"), true); - } else { - return result; - } + auto createNewSubscription = [&](auto &client) { + { + client.set_follow_location(true); + + auto longPollingEndpoint = [&] { + if (!cmd.topic.queryParamMap().contains(LONG_POLLING_IDX_TAG)) { + return URI<>::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build(); + } else { + return URI<>::factory(cmd.topic).build(); } - }; - if (const httplib::Result &result = redirect_get(endpoint, pollHeaders)) { - returnMdpMessage(cmd, result); - } else { // failed or server is down -> wait until retry - std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry - if (_run) { - returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast(result.error()))); + }(); + + const auto pollHeaders = getPreferredContentTypeHeader(longPollingEndpoint); + auto endpoint = longPollingEndpoint.relativeRef().value(); + client.set_read_timeout(cmd.timeout); // default keep-alive value + while (_run) { + auto redirect_get = [&client](auto url, auto headers) { + for (;;) { + auto result = client.Get(url, headers); + if (!result) return result; + + if (result->status >= 300 && result->status < 400) { + url = httplib::detail::decode_url(result.value().get_header_value("location"), true); + } else { + return result; + } + } + }; + if (const httplib::Result &result = redirect_get(endpoint, pollHeaders)) { + returnMdpMessage(cmd, result); + } else { // failed or server is down -> wait until retry + std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry + if (_run) { + returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast(result.error()))); + } } } } + }; + if (equal_with_case_ignore(*cmd.topic.scheme(), "http")) { + auto it = _subscription1.find(cmd.topic); + if (it == _subscription1.end()) { + _subscription1.emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value())); + createNewSubscription(_subscription1.at(cmd.topic)); + } + } else { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (auto it = _subscription2.find(cmd.topic); it == _subscription2.end()) { + _subscription2.emplace( + std::piecewise_construct, + std::forward_as_tuple(cmd.topic), + std::forward_as_tuple(cmd.topic.hostName().value(), cmd.topic.port().value())); + auto &client = _subscription2.at(cmd.topic); + client.set_ca_cert_store(detail::createCertificateStore(_caCertificate)); + client.enable_server_certificate_verification(CHECK_CERTIFICATES); + createNewSubscription(_subscription2.at(cmd.topic)); + } +#else + throw std::invalid_argument("https is not supported"); +#endif } + } else { throw std::invalid_argument(fmt::format("unsupported scheme '{}' for requested subscription '{}'", cmd.topic.scheme(), cmd.topic.str())); } diff --git a/src/client/include/RestDefaultClientCertificates.hpp b/src/client/include/RestDefaultClientCertificates.hpp index 7232150f..61783625 100644 --- a/src/client/include/RestDefaultClientCertificates.hpp +++ b/src/client/include/RestDefaultClientCertificates.hpp @@ -18,6 +18,15 @@ class DefaultCertificate { _concatenated_certificates += root_certificates[1]; _concatenated_certificates += root_certificates[2]; _concatenated_certificates += root_certificates[3]; + + if (auto filename = std::getenv("OPENCMW_REST_CERT_FILE"); filename) { + std::ifstream ifs{ filename }; + if (!ifs.is_open()) { + std::string cert; + ifs >> cert; + _concatenated_certificates += cert; + } + } } constexpr std::string get() const noexcept { return _concatenated_certificates; diff --git a/src/client/test/RestClient_tests.cpp b/src/client/test/RestClient_tests.cpp index 0f30d46c..24561aeb 100644 --- a/src/client/test/RestClient_tests.cpp +++ b/src/client/test/RestClient_tests.cpp @@ -73,8 +73,6 @@ TEST_CASE("Basic Rest Client Constructor and API Tests", "[Client]") { RestClient client5("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate)); REQUIRE(client5.defaultMimeType() == MIME::HTML); REQUIRE(client5.threadPool()->poolName() == "clientName"); - - REQUIRE_THROWS_AS(RestClient(ClientCertificates("Invalid Certificate Format")), std::invalid_argument); } TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") { @@ -123,6 +121,73 @@ TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") { } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT +TEST_CASE("Multiple Rest Client Get/Set Test - HTTPS", "[Client]") { + using namespace opencmw::client; + RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); + REQUIRE(RestClient::CHECK_CERTIFICATES); + RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check + REQUIRE(client.name() == "TestSSLClient"); + REQUIRE(client.defaultMimeType() == MIME::JSON); + + // HTTP + X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); + EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); + if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { + FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + httplib::SSLServer server(cert, pkey); + + std::string acceptHeader; + server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) { + if (req.headers.contains("accept")) { + acceptHeader = req.headers.find("accept")->second; + } else { + FAIL("no accept headers found"); + } + res.set_content("Hello World!", acceptHeader); + }); + client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); + while (!server.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + REQUIRE(server.is_running()); + + std::array, 4> dones; + dones[0] = false; + dones[1] = false; + dones[2] = false; + dones[3] = false; + std::atomic counter{ 0 }; + auto makeCommand = [&]() { + IoBuffer data; + data.put('A'); + data.put('B'); + data.put('C'); + data.put(0); + + Command command; + command.command = mdp::Command::Get; + command.topic = URI("https://localhost:8080/endPoint"); + command.data = std::move(data); + command.callback = [&dones, &counter](const mdp::Message &/*rep*/) { + int currentCounter = counter.fetch_add(1, std::memory_order_relaxed); + dones[currentCounter].store(true, std::memory_order_release); + // Assuming you have access to 'done' variable, uncomment the following line + dones[currentCounter].notify_all(); + }; + client.request(command); + }; + for (int i = 0; i < 4; i++) + makeCommand(); + + for (auto &done : dones) { + done.wait(false); + } + REQUIRE(std::ranges::all_of(dones, [](auto &done) { return done.load(std::memory_order_acquire); })); + REQUIRE(acceptHeader == MIME::JSON.typeName()); + server.stop(); +} + TEST_CASE("Basic Rest Client Get/Set Test - HTTPS", "[Client]") { using namespace opencmw::client; RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); @@ -296,4 +361,108 @@ TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test", "[Client]") { std::cout << "server stopped" << std::endl; } +TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test HTTPS", "[Client]") { + // HTTP + X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); + EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); + if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { + FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + using namespace opencmw::client; + + std::atomic updateCounter{ 0 }; + detail::EventDispatcher eventDispatcher; + httplib::SSLServer server(cert, pkey); + server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) { + DEBUG_LOG("Server received request"); + auto acceptType = req.headers.find("accept"); + if (acceptType == req.headers.end() || MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response +#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) + res.set_content(fmt::format("update counter = {}", updateCounter.load()), MIME::TEXT); +#else + res.set_content(fmt::format("update counter = {}", updateCounter.load()), std::string(MIME::TEXT.typeName())); +#endif + return; + } else { + fmt::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body); +#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) + res.set_chunked_content_provider(MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { +#else + res.set_chunked_content_provider(std::string(MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { +#endif + eventDispatcher.wait_event(sink); + return true; + }); + } + }); + server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) { + fmt::print("server received request on path '{}' body = '{}'\n", req.path, req.body); + res.set_content("Hello World!", "text/plain"); + }); + + RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); + + client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); + while (!server.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + REQUIRE(server.is_running()); + REQUIRE(RestClient::CHECK_CERTIFICATES); + RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check + REQUIRE(client.name() == "TestSSLClient"); + REQUIRE(client.defaultMimeType() == MIME::JSON); + + std::atomic receivedRegular(0); + std::atomic receivedError(0); + IoBuffer data; + data.put('A'); + data.put('B'); + data.put('C'); + data.put(0); + + Command command; + command.command = mdp::Command::Subscribe; + command.topic = URI("https://localhost:8080/event"); + command.data = std::move(data); + command.callback = [&receivedRegular, &receivedError](const mdp::Message &rep) { + fmt::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size()); + if (rep.error.size() == 0) { + receivedRegular.fetch_add(1, std::memory_order_relaxed); + } else { + receivedError.fetch_add(1, std::memory_order_relaxed); + } + receivedRegular.notify_all(); + receivedError.notify_all(); + }; + + client.request(command); + + std::cout << "client request launched" << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + eventDispatcher.send_event("test-event meta data"); + std::jthread dispatcher([&updateCounter, &eventDispatcher] { + while (updateCounter < 5) { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); + } + }); + dispatcher.join(); + + while (receivedRegular.load(std::memory_order_relaxed) < 5) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::cout << "done waiting" << std::endl; + REQUIRE(receivedRegular.load(std::memory_order_acquire) >= 5); + + command.command = mdp::Command::Unsubscribe; + client.request(command); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::cout << "done Unsubscribe" << std::endl; + + client.stop(); + server.stop(); + eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); + std::cout << "server stopped" << std::endl; +} + } // namespace opencmw::rest_client_test