Skip to content

Commit

Permalink
Handle https subscriptions, fix cert store bug
Browse files Browse the repository at this point in the history
https subscriptions were not handled, create a SSLCLient and reuse
existing code for subscribing.

the certificate store was shared between all SslClients in our
RestClient, this does not work, because the client owns the store and
deletes it on its own destruction.

this patch also adds the certificate specified by OPENCMW_REST_CERT_FILE
to our RestDefaultCLientCertificates to make it trusted.
  • Loading branch information
Alexander Busse authored and wirew0rm committed Feb 27, 2024
1 parent a7a7c5c commit ac69eaf
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 44 deletions.
101 changes: 59 additions & 42 deletions src/client/include/RestClientNative.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ class RestClient : public ClientBase {
std::mutex _subscriptionLock;
std::map<URI<STRICT>, httplib::Client> _subscription1;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
X509_STORE *_client_cert_store = nullptr;
std::map<URI<STRICT>, httplib::SSLClient> _subscription2;
#endif

Expand All @@ -166,12 +165,6 @@ class RestClient : public ClientBase {
, _maxIoThreads(detail::find_argument_value<true, MaxIoThreads>([] { return MaxIoThreads(); }, initArgs...))
, _thread_pool(detail::find_argument_value<true, ThreadPoolType>([this] { return std::make_shared<BasicThreadPool<IO_BOUND>>(_name, _minIoThreads, _maxIoThreads); }, initArgs...))
, _caCertificate(detail::find_argument_value<true, ClientCertificates>([] { return rest::DefaultCertificate().get(); }, initArgs...)) {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (_client_cert_store != nullptr) {
X509_STORE_free(_client_cert_store);
}
_client_cert_store = detail::createCertificateStore(_caCertificate);
#endif
}
~RestClient() override { RestClient::stop(); };

Expand Down Expand Up @@ -285,7 +278,8 @@ class RestClient : public ClientBase {
if (cmd.topic.scheme() && equal_with_case_ignore(cmd.topic.scheme().value(), "https")) {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
httplib::SSLClient client(cmd.topic.hostName().value(), cmd.topic.port() ? cmd.topic.port().value() : 443);
client.set_ca_cert_store(_client_cert_store);
// client owns its certificate store and destroys it after use. create a store for each client
client.set_ca_cert_store(detail::createCertificateStore(_caCertificate));
client.enable_server_certificate_verification(CHECK_CERTIFICATES);
callback(client);
#else
Expand Down Expand Up @@ -315,45 +309,68 @@ class RestClient : public ClientBase {
|| equal_with_case_ignore(*cmd.topic.scheme(), "https")
#endif
) {
auto it = _subscription1.find(cmd.topic);
if (it == _subscription1.end()) {
auto &client = _subscription1.try_emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value())).first->second;
client.set_follow_location(true);

auto longPollingEndpoint = [&] {
if (!cmd.topic.queryParamMap().contains(LONG_POLLING_IDX_TAG)) {
return URI<>::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build();
} else {
return URI<>::factory(cmd.topic).build();
}
}();

const auto pollHeaders = getPreferredContentTypeHeader(longPollingEndpoint);
auto endpoint = longPollingEndpoint.relativeRef().value();
client.set_read_timeout(cmd.timeout); // default keep-alive value
while (_run) {
auto redirect_get = [&client](auto url, auto headers) {
for (;;) {
auto result = client.Get(url, headers);
if (!result) return result;

if (result->status >= 300 && result->status < 400) {
url = httplib::detail::decode_url(result.value().get_header_value("location"), true);
} else {
return result;
}
auto createNewSubscription = [&](auto &client) {
{
client.set_follow_location(true);

auto longPollingEndpoint = [&] {
if (!cmd.topic.queryParamMap().contains(LONG_POLLING_IDX_TAG)) {
return URI<>::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build();
} else {
return URI<>::factory(cmd.topic).build();
}
};
if (const httplib::Result &result = redirect_get(endpoint, pollHeaders)) {
returnMdpMessage(cmd, result);
} else { // failed or server is down -> wait until retry
std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry
if (_run) {
returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast<int>(result.error())));
}();

const auto pollHeaders = getPreferredContentTypeHeader(longPollingEndpoint);
auto endpoint = longPollingEndpoint.relativeRef().value();
client.set_read_timeout(cmd.timeout); // default keep-alive value
while (_run) {
auto redirect_get = [&client](auto url, auto headers) {
for (;;) {
auto result = client.Get(url, headers);
if (!result) return result;

if (result->status >= 300 && result->status < 400) {
url = httplib::detail::decode_url(result.value().get_header_value("location"), true);
} else {
return result;
}
}
};
if (const httplib::Result &result = redirect_get(endpoint, pollHeaders)) {
returnMdpMessage(cmd, result);
} else { // failed or server is down -> wait until retry
std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry
if (_run) {
returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast<int>(result.error())));
}
}
}
}
};
if (equal_with_case_ignore(*cmd.topic.scheme(), "http")) {
auto it = _subscription1.find(cmd.topic);
if (it == _subscription1.end()) {
_subscription1.emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value()));
createNewSubscription(_subscription1.at(cmd.topic));
}
} else {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (auto it = _subscription2.find(cmd.topic); it == _subscription2.end()) {
_subscription2.emplace(
std::piecewise_construct,
std::forward_as_tuple(cmd.topic),
std::forward_as_tuple(cmd.topic.hostName().value(), cmd.topic.port().value()));
auto &client = _subscription2.at(cmd.topic);
client.set_ca_cert_store(detail::createCertificateStore(_caCertificate));
client.enable_server_certificate_verification(CHECK_CERTIFICATES);
createNewSubscription(_subscription2.at(cmd.topic));
}
#else
throw std::invalid_argument("https is not supported");
#endif
}

} else {
throw std::invalid_argument(fmt::format("unsupported scheme '{}' for requested subscription '{}'", cmd.topic.scheme(), cmd.topic.str()));
}
Expand Down
9 changes: 9 additions & 0 deletions src/client/include/RestDefaultClientCertificates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ class DefaultCertificate {
_concatenated_certificates += root_certificates[1];
_concatenated_certificates += root_certificates[2];
_concatenated_certificates += root_certificates[3];

if (auto filename = std::getenv("OPENCMW_REST_CERT_FILE"); filename) {
std::ifstream ifs{ filename };
if (!ifs.is_open()) {
std::string cert;
ifs >> cert;
_concatenated_certificates += cert;
}
}
}
constexpr std::string get() const noexcept {
return _concatenated_certificates;
Expand Down
173 changes: 171 additions & 2 deletions src/client/test/RestClient_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ TEST_CASE("Basic Rest Client Constructor and API Tests", "[Client]") {
RestClient client5("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate));
REQUIRE(client5.defaultMimeType() == MIME::HTML);
REQUIRE(client5.threadPool()->poolName() == "clientName");

REQUIRE_THROWS_AS(RestClient(ClientCertificates("Invalid Certificate Format")), std::invalid_argument);
}

TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") {
Expand Down Expand Up @@ -123,6 +121,73 @@ TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") {
}

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST_CASE("Multiple Rest Client Get/Set Test - HTTPS", "[Client]") {
using namespace opencmw::client;
RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate));
REQUIRE(RestClient::CHECK_CERTIFICATES);
RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check
REQUIRE(client.name() == "TestSSLClient");
REQUIRE(client.defaultMimeType() == MIME::JSON);

// HTTP
X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate);
EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey);
if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) {
FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr)));
}
httplib::SSLServer server(cert, pkey);

std::string acceptHeader;
server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) {
if (req.headers.contains("accept")) {
acceptHeader = req.headers.find("accept")->second;
} else {
FAIL("no accept headers found");
}
res.set_content("Hello World!", acceptHeader);
});
client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); });
while (!server.is_running()) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
REQUIRE(server.is_running());

std::array<std::atomic<bool>, 4> dones;
dones[0] = false;
dones[1] = false;
dones[2] = false;
dones[3] = false;
std::atomic<int> counter{ 0 };
auto makeCommand = [&]() {
IoBuffer data;
data.put('A');
data.put('B');
data.put('C');
data.put(0);

Command command;
command.command = mdp::Command::Get;
command.topic = URI<STRICT>("https://localhost:8080/endPoint");
command.data = std::move(data);
command.callback = [&dones, &counter](const mdp::Message &/*rep*/) {
int currentCounter = counter.fetch_add(1, std::memory_order_relaxed);
dones[currentCounter].store(true, std::memory_order_release);
// Assuming you have access to 'done' variable, uncomment the following line
dones[currentCounter].notify_all();
};
client.request(command);
};
for (int i = 0; i < 4; i++)
makeCommand();

for (auto &done : dones) {
done.wait(false);
}
REQUIRE(std::ranges::all_of(dones, [](auto &done) { return done.load(std::memory_order_acquire); }));
REQUIRE(acceptHeader == MIME::JSON.typeName());
server.stop();
}

TEST_CASE("Basic Rest Client Get/Set Test - HTTPS", "[Client]") {
using namespace opencmw::client;
RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate));
Expand Down Expand Up @@ -296,4 +361,108 @@ TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test", "[Client]") {
std::cout << "server stopped" << std::endl;
}

TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test HTTPS", "[Client]") {
// HTTP
X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate);
EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey);
if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) {
FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr)));
}
using namespace opencmw::client;

std::atomic<int> updateCounter{ 0 };
detail::EventDispatcher eventDispatcher;
httplib::SSLServer server(cert, pkey);
server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) {
DEBUG_LOG("Server received request");
auto acceptType = req.headers.find("accept");
if (acceptType == req.headers.end() || MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response
#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16))
res.set_content(fmt::format("update counter = {}", updateCounter.load()), MIME::TEXT);
#else
res.set_content(fmt::format("update counter = {}", updateCounter.load()), std::string(MIME::TEXT.typeName()));
#endif
return;
} else {
fmt::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body);
#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16))
res.set_chunked_content_provider(MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) {
#else
res.set_chunked_content_provider(std::string(MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) {
#endif
eventDispatcher.wait_event(sink);
return true;
});
}
});
server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) {
fmt::print("server received request on path '{}' body = '{}'\n", req.path, req.body);
res.set_content("Hello World!", "text/plain");
});

RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate));

client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); });
while (!server.is_running()) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
REQUIRE(server.is_running());
REQUIRE(RestClient::CHECK_CERTIFICATES);
RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check
REQUIRE(client.name() == "TestSSLClient");
REQUIRE(client.defaultMimeType() == MIME::JSON);

std::atomic<int> receivedRegular(0);
std::atomic<int> receivedError(0);
IoBuffer data;
data.put('A');
data.put('B');
data.put('C');
data.put(0);

Command command;
command.command = mdp::Command::Subscribe;
command.topic = URI<STRICT>("https://localhost:8080/event");
command.data = std::move(data);
command.callback = [&receivedRegular, &receivedError](const mdp::Message &rep) {
fmt::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size());
if (rep.error.size() == 0) {
receivedRegular.fetch_add(1, std::memory_order_relaxed);
} else {
receivedError.fetch_add(1, std::memory_order_relaxed);
}
receivedRegular.notify_all();
receivedError.notify_all();
};

client.request(command);

std::cout << "client request launched" << std::endl;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
eventDispatcher.send_event("test-event meta data");
std::jthread dispatcher([&updateCounter, &eventDispatcher] {
while (updateCounter < 5) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++));
}
});
dispatcher.join();

while (receivedRegular.load(std::memory_order_relaxed) < 5) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
std::cout << "done waiting" << std::endl;
REQUIRE(receivedRegular.load(std::memory_order_acquire) >= 5);

command.command = mdp::Command::Unsubscribe;
client.request(command);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::cout << "done Unsubscribe" << std::endl;

client.stop();
server.stop();
eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++));
std::cout << "server stopped" << std::endl;
}

} // namespace opencmw::rest_client_test

0 comments on commit ac69eaf

Please sign in to comment.