Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client-v2] Added implementation for Bearer token auth #1904

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions client-v2/src/main/java/com/clickhouse/client/api/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
* ...
* }
* }
*
* }
*
*
Expand All @@ -131,6 +130,9 @@ public class Client implements AutoCloseable {

private final Set<String> endpoints;
private final Map<String, String> configuration;

private final Map<String, String> readOnlyConfig;

private final List<ClickHouseNode> serverNodes = new ArrayList<>();

// POJO serializer mapping (class -> (schema -> (format -> serializer)))
Expand All @@ -157,6 +159,7 @@ private Client(Set<String> endpoints, Map<String,String> configuration, boolean
ExecutorService sharedOperationExecutor, ColumnToMethodMatchingStrategy columnToMethodMatchingStrategy) {
this.endpoints = endpoints;
this.configuration = configuration;
this.readOnlyConfig = Collections.unmodifiableMap(this.configuration);
this.endpoints.forEach(endpoint -> {
this.serverNodes.add(ClickHouseNode.of(endpoint, this.configuration));
});
Expand Down Expand Up @@ -852,7 +855,7 @@ public Builder allowBinaryReaderToReuseBuffers(boolean reuse) {
* @return same instance of the builder
*/
public Builder httpHeader(String key, String value) {
this.configuration.put(ClientConfigProperties.HTTP_HEADER_PREFIX + key.toUpperCase(Locale.US), value);
this.configuration.put(ClientConfigProperties.httpHeader(key), value);
return this;
}

Expand All @@ -863,7 +866,7 @@ public Builder httpHeader(String key, String value) {
* @return same instance of the builder
*/
public Builder httpHeader(String key, Collection<String> values) {
this.configuration.put(ClientConfigProperties.HTTP_HEADER_PREFIX + key.toUpperCase(Locale.US), ClientConfigProperties.commaSeparated(values));
this.configuration.put(ClientConfigProperties.httpHeader(key), ClientConfigProperties.commaSeparated(values));
return this;
}

Expand Down Expand Up @@ -954,6 +957,19 @@ public Builder setOptions(Map<String, String> options) {
return this;
}

/**
* Specifies whether to use Bearer Authentication and what token to use.
* The token will be sent as is, so it should be encoded before passing to this method.
*
* @param bearerToken - token to use
* @return same instance of the builder
*/
public Builder useBearerTokenAuth(String bearerToken) {
// Most JWT libraries (https://jwt.io/libraries?language=Java) compact tokens in proper way
this.httpHeader(HttpHeaders.AUTHORIZATION, "Bearer " + bearerToken);
return this;
}

public Client build() {
setDefaults();

Expand All @@ -964,8 +980,9 @@ public Client build() {
// check if username and password are empty. so can not initiate client?
if (!this.configuration.containsKey("access_token") &&
(!this.configuration.containsKey("user") || !this.configuration.containsKey("password")) &&
!MapUtils.getFlag(this.configuration, "ssl_authentication", false)) {
throw new IllegalArgumentException("Username and password (or access token, or SSL authentication) are required");
!MapUtils.getFlag(this.configuration, "ssl_authentication", false) &&
!this.configuration.containsKey(ClientConfigProperties.httpHeader(HttpHeaders.AUTHORIZATION))) {
throw new IllegalArgumentException("Username and password (or access token or SSL authentication or pre-define Authorization header) are required");
}

if (this.configuration.containsKey("ssl_authentication") &&
Expand Down Expand Up @@ -1011,7 +1028,8 @@ public Client build() {
throw new IllegalArgumentException("Nor server timezone nor specific timezone is set");
}

return new Client(this.endpoints, this.configuration, this.useNewImplementation, this.sharedOperationExecutor, this.columnToMethodMatchingStrategy);
return new Client(this.endpoints, this.configuration, this.useNewImplementation, this.sharedOperationExecutor,
this.columnToMethodMatchingStrategy);
}

private static final int DEFAULT_NETWORK_BUFFER_SIZE = 300_000;
Expand Down Expand Up @@ -2103,7 +2121,7 @@ public String toString() {
* @return - configuration options
*/
public Map<String, String> getConfiguration() {
return Collections.unmodifiableMap(configuration);
return readOnlyConfig;
}

/** Returns operation timeout in seconds */
Expand Down Expand Up @@ -2150,6 +2168,10 @@ public Collection<String> getDBRoles() {
return unmodifiableDbRolesView;
}

public void updateBearerToken(String bearer) {
this.configuration.put(ClientConfigProperties.httpHeader(HttpHeaders.AUTHORIZATION), "Bearer " + bearer);
}

private ClickHouseNode getNextAliveNode() {
return serverNodes.get(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -157,6 +158,10 @@ public static String serverSetting(String key) {
return SERVER_SETTING_PREFIX + key;
}

public static String httpHeader(String key) {
return HTTP_HEADER_PREFIX + key.toUpperCase(Locale.US);
}

public static String commaSeparated(Collection<?> values) {
StringBuilder sb = new StringBuilder();
for (Object value : values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,34 @@ public class ServerException extends RuntimeException {
public static final int TABLE_NOT_FOUND = 60;

private final int code;

private final int transportProtocolCode;

public ServerException(int code, String message) {
this(code, message, 500);
}

public ServerException(int code, String message, int transportProtocolCode) {
super(message);
this.code = code;
this.transportProtocolCode = transportProtocolCode;
}

/**
* Returns CH server error code. May return 0 if code is unknown.
* @return - error code from server response
*/
public int getCode() {
return code;
}

/**
* Returns error code of underlying transport protocol. For example, HTTP status.
* By default, will return {@code 500 } what is derived from HTTP Server Internal Error.
*
* @return - transport status code
*/
public int getTransportProtocolCode() {
return transportProtocolCode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.clickhouse.client.api.ClientConfigProperties.SOCKET_TCP_NO_DELAY_OPT;

Expand Down Expand Up @@ -335,10 +336,13 @@ public Exception readError(ClassicHttpResponse httpResponse) {

String msg = msgBuilder.toString().replaceAll("\\s+", " ").replaceAll("\\\\n", " ")
.replaceAll("\\\\/", "/");
return new ServerException(serverCode, msg);
if (msg.trim().isEmpty()) {
msg = String.format(ERROR_CODE_PREFIX_PATTERN, serverCode) + " <Unreadable error message> (transport error: " + httpResponse.getCode() + ")";
}
return new ServerException(serverCode, msg, httpResponse.getCode());
} catch (Exception e) {
LOG.error("Failed to read error message", e);
return new ServerException(serverCode, String.format(ERROR_CODE_PREFIX_PATTERN, serverCode) + " <Unreadable error message>");
return new ServerException(serverCode, String.format(ERROR_CODE_PREFIX_PATTERN, serverCode) + " <Unreadable error message> (transport error: " + httpResponse.getCode() + ")", httpResponse.getCode());
}
}

Expand Down Expand Up @@ -450,12 +454,12 @@ private void addHeaders(HttpPost req, Map<String, String> chConfig, Map<String,

for (Map.Entry<String, String> entry : chConfig.entrySet()) {
if (entry.getKey().startsWith(ClientConfigProperties.HTTP_HEADER_PREFIX)) {
req.addHeader(entry.getKey().substring(ClientConfigProperties.HTTP_HEADER_PREFIX.length()), entry.getValue());
req.setHeader(entry.getKey().substring(ClientConfigProperties.HTTP_HEADER_PREFIX.length()), entry.getValue());
}
}
for (Map.Entry<String, Object> entry : requestConfig.entrySet()) {
if (entry.getKey().startsWith(ClientConfigProperties.HTTP_HEADER_PREFIX)) {
req.addHeader(entry.getKey().substring(ClientConfigProperties.HTTP_HEADER_PREFIX.length()), entry.getValue().toString());
req.setHeader(entry.getKey().substring(ClientConfigProperties.HTTP_HEADER_PREFIX.length()), entry.getValue().toString());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.io.ByteArrayInputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
Expand All @@ -51,8 +52,11 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
import java.util.function.Supplier;

import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;

public class HttpTransportTests extends BaseIntegrationTest {
Expand All @@ -66,7 +70,6 @@ public void testConnectionTTL(Long connectionTtl, Long keepAlive, int openSocket
ClickHouseNode server = getServer(ClickHouseProtocol.HTTP);

int proxyPort = new Random().nextInt(1000) + 10000;
System.out.println("proxyPort: " + proxyPort);
ConnectionCounterListener connectionCounter = new ConnectionCounterListener();
WireMockServer proxy = new WireMockServer(WireMockConfiguration
.options().port(proxyPort)
Expand Down Expand Up @@ -154,7 +157,6 @@ public void closed(Socket socket) {
public void testConnectionRequestTimeout() {

int serverPort = new Random().nextInt(1000) + 10000;
System.out.println("proxyPort: " + serverPort);
ConnectionCounterListener connectionCounter = new ConnectionCounterListener();
WireMockServer proxy = new WireMockServer(WireMockConfiguration
.options().port(serverPort)
Expand Down Expand Up @@ -794,4 +796,71 @@ public static Object[][] testUserAgentHasCompleteProductName_dataProvider() {
{ "test-client/1.0", Pattern.compile("test-client/1.0 clickhouse-java-v2\\/.+ \\(.+\\) Apache HttpClient\\/[\\d\\.]+$")},
{ "test-client/", Pattern.compile("test-client/ clickhouse-java-v2\\/.+ \\(.+\\) Apache HttpClient\\/[\\d\\.]+$")}};
}

@Test(groups = { "integration" })
public void testBearerTokenAuth() throws Exception {
WireMockServer mockServer = new WireMockServer( WireMockConfiguration
.options().port(9090).notifier(new ConsoleNotifier(false)));
mockServer.start();

String jwtToken1 = Arrays.stream(
new String[]{"header", "payload", "signature"})
.map(s -> Base64.getEncoder().encodeToString(s.getBytes(StandardCharsets.UTF_8)))
.reduce((s1, s2) -> s1 + "." + s2).get();
try (Client client = new Client.Builder().addEndpoint(Protocol.HTTP, "localhost", mockServer.port(), false)
.useBearerTokenAuth(jwtToken1)
.compressServerResponse(false)
.build()) {

mockServer.addStubMapping(WireMock.post(WireMock.anyUrl())
.withHeader("Authorization", WireMock.equalTo("Bearer " + jwtToken1))
.willReturn(WireMock.aResponse()
.withHeader("X-ClickHouse-Summary",
"{ \"read_bytes\": \"10\", \"read_rows\": \"1\"}")).build());

try (QueryResponse response = client.query("SELECT 1").get(1, TimeUnit.SECONDS)) {
Assert.assertEquals(response.getReadBytes(), 10);
} catch (Exception e) {
Assert.fail("Unexpected exception", e);
}
}

String jwtToken2 = Arrays.stream(
new String[]{"header2", "payload2", "signature2"})
.map(s -> Base64.getEncoder().encodeToString(s.getBytes(StandardCharsets.UTF_8)))
.reduce((s1, s2) -> s1 + "." + s2).get();

mockServer.resetAll();
mockServer.addStubMapping(WireMock.post(WireMock.anyUrl())
.withHeader("Authorization", WireMock.equalTo("Bearer " + jwtToken1))
.willReturn(WireMock.aResponse()
.withStatus(HttpStatus.SC_UNAUTHORIZED))
.build());

try (Client client = new Client.Builder().addEndpoint(Protocol.HTTP, "localhost", mockServer.port(), false)
.useBearerTokenAuth(jwtToken1)
.compressServerResponse(false)
.build()) {

try {
client.execute("SELECT 1").get();
fail("Exception expected");
} catch (ServerException e) {
Assert.assertEquals(e.getTransportProtocolCode(), HttpStatus.SC_UNAUTHORIZED);
}

mockServer.resetAll();
mockServer.addStubMapping(WireMock.post(WireMock.anyUrl())
.withHeader("Authorization", WireMock.equalTo("Bearer " + jwtToken2))
.willReturn(WireMock.aResponse()
.withHeader("X-ClickHouse-Summary",
"{ \"read_bytes\": \"10\", \"read_rows\": \"1\"}"))

.build());

client.updateBearerToken(jwtToken2);

client.execute("SELECT 1").get();
}
}
}
Loading