Skip to content

Commit

Permalink
Merge pull request #326 from feature/jcr-ai-improvements
Browse files Browse the repository at this point in the history
JCR AI improvements ; permit using Anthropic Claude
  • Loading branch information
stoerr committed Apr 18, 2024
2 parents c0c2801 + 6f6856c commit c017033
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.collections4.ListUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.io.Charsets;
import org.apache.commons.io.FileUtils;
Expand All @@ -27,6 +26,7 @@
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
import org.jetbrains.annotations.NotNull;
import org.osgi.framework.Constants;
import org.osgi.service.component.annotations.Activate;
import org.osgi.service.component.annotations.Component;
Expand All @@ -53,22 +53,31 @@ public class AIServiceImpl implements AIService {

public static final String SERVICE_NAME = "Composum Nodes AI Service";

private static final Logger LOG = LoggerFactory.getLogger(AIServiceImpl.class);
protected static final Logger LOG = LoggerFactory.getLogger(AIServiceImpl.class);

/** The default model - probably a GPT-4 is needed for complicated stuff like JCR queries. */
/**
* The default model - probably a GPT-4 is needed for complicated stuff like JCR queries.
*/
public static final String DEFAULT_MODEL = "gpt-4-turbo-preview";
protected static final String CHAT_COMPLETION_URL = "https://api.openai.com/v1/chat/completions";

private Configuration config;
private String openAiApiKey;
private RateLimiter rateLimiter;
private Gson gson = new Gson();
private CloseableHttpClient httpClient;
private String organizationId;
protected Configuration config;
protected String apiKey;
protected RateLimiter rateLimiter;
protected Gson gson = new Gson();
protected CloseableHttpClient httpClient;
protected String chatURL;
protected String apiKeyHeader;
protected String additionalHeader;
protected String additionalHeaderValue;

@Override
public boolean isAvailable() {
return config != null && !config.disabled() && StringUtils.isNotBlank(openAiApiKey);
return config != null && !config.disabled() && StringUtils.isNotBlank(apiKey);
}

/**
 * Tells whether the configured chat URL points to the Anthropic API.
 *
 * @return true if a chat URL is set and contains "api.anthropic.com"
 */
protected boolean isAnthropicClaude() {
    if (chatURL == null) {
        return false;
    }
    return chatURL.contains("api.anthropic.com");
}

@Nonnull
Expand All @@ -83,7 +92,10 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla

Map<String, Object> request = new HashMap<>();
request.put("model", StringUtils.defaultIfBlank(config.defaultModel(), DEFAULT_MODEL));
if (systemmsg != null) {
if (systemmsg != null && isAnthropicClaude()) {
request.put("system", systemmsg);
request.put("messages", asList(userMessage));
} else if (systemmsg != null) {
Map<String, Object> systemMessage = new LinkedHashMap<>();
systemMessage.put("role", "system");
systemMessage.put("content", systemmsg);
Expand All @@ -92,7 +104,10 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla
request.put("messages", asList(userMessage));
}
request.put("temperature", 0);
if (responseFormat == ResponseFormat.JSON) {
if (config.maxTokens() > 0) {
request.put("max_tokens", config.maxTokens());
}
if (responseFormat == ResponseFormat.JSON && !isAnthropicClaude()) {
Map<String, String> responseFormatMap = new LinkedHashMap<>();
responseFormatMap.put("type", "json_object");
request.put("response_format", responseFormatMap);
Expand All @@ -102,15 +117,20 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla

rateLimiter.waitForLimit();
// retrieve response from OpenAI using httpClient 4
HttpPost postRequest = new HttpPost(CHAT_COMPLETION_URL);
HttpPost postRequest = new HttpPost(this.chatURL);
EntityBuilder entityBuilder = EntityBuilder.create();
entityBuilder.setContentType(ContentType.APPLICATION_JSON);
entityBuilder.setContentEncoding("UTF-8");
entityBuilder.setText(requestJson);
postRequest.setEntity(entityBuilder.build());
postRequest.setHeader("Authorization", "Bearer " + openAiApiKey);
if (organizationId != null) {
postRequest.setHeader("OpenAI-Organization", organizationId);
if (this.apiKeyHeader != null) {
String[] headerSplitted = this.apiKeyHeader.split("\\s");
String headername = headerSplitted[0];
String prefix = headerSplitted.length > 1 ? headerSplitted[1].trim() + " " : "";
postRequest.setHeader(headername, prefix + apiKey);
}
if (additionalHeader != null && additionalHeaderValue != null) {
postRequest.setHeader(additionalHeader, additionalHeaderValue);
}
String id = "#" + System.nanoTime();
LOG.debug("Request {} to OpenAI: {}", id, requestJson);
Expand All @@ -123,7 +143,7 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla
}

@Nonnull
private String retrieveMessage(String id, CloseableHttpResponse response) throws AIServiceException, IOException {
protected String retrieveMessage(String id, CloseableHttpResponse response) throws AIServiceException, IOException {
int statusCode = response.getStatusLine().getStatusCode();
HttpEntity responseEntity = response.getEntity();
if (statusCode != 200) {
Expand All @@ -133,27 +153,38 @@ private String retrieveMessage(String id, CloseableHttpResponse response) throws
if (bytes.size() > 0) {
errorbody = errorbody + "\n" + new String(bytes.toByteArray(), Charsets.UTF_8);
}
throw new AIServiceException("Error from OpenAI: " + statusCode + " " + errorbody);
throw new AIServiceException("Error from AI backend: " + response.getStatusLine() + " " + errorbody);
}
String responseJson = EntityUtils.toString(responseEntity);
LOG.debug("Response {} from OpenAI: {}", id, responseJson);
Map<String, Object> responseMap = gson.fromJson(responseJson, Map.class);
return extractText(responseJson);
}

List<Map<String, Object>> choices = ListUtils.emptyIfNull((List<Map<String, Object>>) responseMap.get("choices"));
if (choices.isEmpty()) {
throw new AIServiceException("No choices in response: " + responseJson);
}
Map<String, Object> choice = choices.get(0);
String finish_reason = (String) choice.get("finish_reason");
if (!"stop".equals(finish_reason)) {
throw new AIServiceException("Finish reason is not stop: " + finish_reason + " in response: " + responseJson);
/**
 * Extracts the generated text from the JSON body of a backend response.
 * <p>
 * Two layouts are recognized: a top level "choices" list (OpenAI format), where the text is read from
 * message.content of the first choice and its finish_reason has to be "stop", and a top level "content" list
 * (Anthropic Claude format), where the "text" of the first entry is returned.
 *
 * @param responseJson the response body as JSON text
 * @return the extracted text, never null
 * @throws AIServiceException if the body is empty, has an unrecognized layout, contains no entries,
 *                            was not finished with "stop" (OpenAI format) or carries no text
 */
@NotNull
protected String extractText(String responseJson) throws AIServiceException {
    Map<String, Object> responseMap = gson.fromJson(responseJson, Map.class);
    if (responseMap == null) { // Gson yields null for an empty body or a literal JSON null
        LOG.error("Empty response: {}", responseJson);
        throw new AIServiceException("Empty response: " + responseJson);
    }
    String text;
    if (responseMap.containsKey("choices")) { // OpenAI format
        List<Map<String, Object>> choices = (List<Map<String, Object>>) responseMap.get("choices");
        // report an empty result as a service error instead of failing with an IndexOutOfBoundsException
        if (choices == null || choices.isEmpty()) {
            throw new AIServiceException("No choices in response: " + responseJson);
        }
        Map<String, Object> choice = MapUtils.emptyIfNull(choices.get(0));
        String finish_reason = (String) choice.get("finish_reason");
        if (!"stop".equals(finish_reason)) {
            throw new AIServiceException("Finish reason is not stop: " + finish_reason + " in response: " + responseJson);
        }
        Map<String, Object> message = MapUtils.emptyIfNull((Map<String, Object>) choice.get("message"));
        text = (String) message.get("content");
    } else if (responseMap.containsKey("content")) { // Anthropic Claude format
        List<Map<String, Object>> content = (List<Map<String, Object>>) responseMap.get("content");
        if (content == null || content.isEmpty()) {
            throw new AIServiceException("No content in response: " + responseJson);
        }
        Map<String, Object> message = MapUtils.emptyIfNull(content.get(0));
        text = (String) message.get("text");
    } else {
        LOG.error("Response format not recognized: {}", responseJson);
        throw new AIServiceException("Response format not recognized: " + responseJson);
    }
    if (text == null) {
        LOG.error("No message in response: {}", responseJson);
        throw new AIServiceException("No message in response: " + responseJson);
    }
    return text;
}

Expand All @@ -173,18 +204,22 @@ protected void deactivate() {
@Modified
protected void activate(Configuration configuration) {
this.config = configuration;
this.openAiApiKey = StringUtils.defaultIfBlank(config.openAiApiKey(), System.getenv("OPENAI_API_KEY"));
this.openAiApiKey = StringUtils.defaultIfBlank(this.openAiApiKey, System.getProperty("openai.api.key"));
chatURL = StringUtils.defaultIfBlank(config.chatCompletionUrl(), CHAT_COMPLETION_URL).trim();
String envVariable = isAnthropicClaude() ? "ANTHROPIC_API_KEY" : "OPENAI_API_KEY";
this.apiKey = StringUtils.defaultIfBlank(config.openAiApiKey(), System.getenv(envVariable));
this.apiKey = StringUtils.defaultIfBlank(this.apiKey, System.getProperty("openai.api.key"));
if (config.openAiApiKeyFile() != null && !config.openAiApiKeyFile().isEmpty()) {
try {
this.openAiApiKey = StringUtils.defaultIfBlank(this.openAiApiKey,
this.apiKey = StringUtils.defaultIfBlank(this.apiKey,
FileUtils.readFileToString(new File(config.openAiApiKeyFile()), Charsets.UTF_8));
} catch (IOException e) {
LOG.error("Could not read OpenAI key from {}", config.openAiApiKeyFile(), e);
}
}
openAiApiKey = StringUtils.trimToNull(openAiApiKey);
organizationId = StringUtils.trimToNull(config.organizationId());
apiKey = StringUtils.trimToNull(apiKey);
additionalHeader = StringUtils.trimToNull(config.additionalHeader());
additionalHeaderValue = StringUtils.trimToNull(config.additionalHeaderValue());
apiKeyHeader = StringUtils.defaultIfBlank(config.apiKeyHeader(), "Authorization Bearer");
rateLimiter = null;
if (isAvailable()) {
int perMinuteLimit = config.requestsPerMinuteLimit() > 0 ? config.requestsPerMinuteLimit() : 20;
Expand All @@ -203,27 +238,43 @@ protected void activate(Configuration configuration) {
@AttributeDefinition(name = "Disable the GPT Chat Completion Service", description = "Disable the GPT Chat Completion Service", defaultValue = "false")
boolean disabled() default false; // we want it to work by just deploying it. Admittedly this is a bit doubtful.

@AttributeDefinition(name = "OpenAI API Key from https://platform.openai.com/. If not given, we check the key file, the environment Variable OPENAI_API_KEY, and the system property openai.api.key .")
/**
 * URL the chat request is posted to. When blank, activate() falls back to CHAT_COMPLETION_URL; a URL containing
 * "api.anthropic.com" makes the service use the Anthropic Claude request format.
 */
@AttributeDefinition(name = "Chat Completion URL", description = "The URL for chat completions. If not given, the default for OpenAI is used: " + CHAT_COMPLETION_URL +
" . In the case of Anthropic Claude it is https://api.anthropic.com/v1/messages.")
String chatCompletionUrl();

/**
 * API key given directly in the configuration. activate() prefers this value and only when it is blank consults
 * the environment variable, then the system property openai.api.key, then the key file.
 */
@AttributeDefinition(name = "API key",
description = "Key for requests to AI backend, in the case of OpenAI from https://platform.openai.com/api-keys. If not given, we check the key file, the environment Variable OPENAI_API_KEY (or in the case of Anthropic Claude ANTHROPIC_API_KEY), and the system property openai.api.key .")
String openAiApiKey();

// alternatively, a key file
@AttributeDefinition(name = "OpenAI API Key File containing the API key, as an alternative to Open AKI Key configuration and the variants described there.")
/**
 * Path of a file holding the API key. activate() reads it as UTF-8 and uses its content only if no key was found
 * from the other sources; the resulting key is trimmed.
 */
@AttributeDefinition(name = "API key file",
description = "File containing the API key, as an alternative to API key configuration and the variants described there.")
String openAiApiKeyFile();

@AttributeDefinition(name = "Organization ID", description = "The organization ID from OpenAI, if you have one.")
String organizationId();
/**
 * Header used to transmit the API key. The value is split at whitespace: the first word is the header name, an
 * optional second word is put, followed by a space, in front of the key. When blank, activate() uses
 * "Authorization Bearer".
 */
@AttributeDefinition(name = "Header for API key", description = "The header in the request that is used for sending the API key. If it's two words like 'Authorization Bearer' then that second word is used as prefix for the value. Default: 'Authorization Bearer', as used by OpenAI.")
String apiKeyHeader();

/**
 * Name of an optional extra request header. It is only sent when both this name and
 * the additional header value are configured (non-blank).
 */
@AttributeDefinition(name = "Additional Header", description = "Optionally, an additional header. In the case of OpenAI that could be 'OpenAI-Organization'. In the case of Anthropic Claude it could be 'anthropic-version'.")
String additionalHeader();

/**
 * Value sent for the optional extra request header. It is trimmed in activate() and only sent when both the
 * header name and this value are non-blank.
 */
@AttributeDefinition(name = "Additional Header Value", description = "The value for the additional header. In the case of OpenAI that could be the organization id. In the case of Anthropic Claude it could be the version of the API, e.g. '2023-06-01'.")
String additionalHeaderValue();

@AttributeDefinition(name = "Default model to use for the chat completion. If not configured we take a default, in this version " + DEFAULT_MODEL + ". Please consider the varying prices https://openai.com/pricing . For programming related questions a GPT-4 seems necessary, though. Do not configure if not necessary, to follow further changes.")
/**
 * Model name put into the "model" field of each request; when blank, DEFAULT_MODEL is sent instead.
 */
@AttributeDefinition(name = "Default model",
description = "Default model to use for the chat completion. If not configured we take a default, in this version " + DEFAULT_MODEL + ". " +
"Please consider the varying prices https://openai.com/pricing or whatever service you use. For programming related questions a GPT-4 / Claude Opus seems necessary, though. Do not configure if not necessary, to follow further changes.")
String defaultModel();

@AttributeDefinition(name = "Requests per minute", description = "The number of requests per minute - after half of that we do slow down. >0, the default is 100.", defaultValue = "20")
/**
 * Request limit per minute handed to the rate limiter. activate() substitutes 20 for values that are not
 * positive, which matches the declared default.
 */
@AttributeDefinition(name = "Requests per minute", description = "The number of requests per minute - after half of that we do slow down. >0, the default is 20.")
int requestsPerMinuteLimit() default 20;

@AttributeDefinition(name = "Requests per hour", description = "The number of requests per hour - after half of that we do slow down. >0, the default is 1000.", defaultValue = "60")
/**
 * Request limit per hour; the declared default is 60.
 */
@AttributeDefinition(name = "Requests per hour", description = "The number of requests per hour - after half of that we do slow down. >0, the default is 60.")
int requestsPerHourLimit() default 60;

@AttributeDefinition(name = "Requests per day", description = "The number of requests per day - after half of that we do slow down. >0, the default is 12000.", defaultValue = "120")
/**
 * Request limit per day; the declared default is 120.
 */
@AttributeDefinition(name = "Requests per day", description = "The number of requests per day - after half of that we do slow down. >0, the default is 120.")
int requestsPerDayLimit() default 120;

/**
 * Upper bound for the response length. It is sent as "max_tokens" only when greater than 0; otherwise the field
 * is left out of the request.
 */
@AttributeDefinition(name = "Maximum number of tokens", description = "Maximum number of tokens in the response. >0, the default is 2048.")
int maxTokens() default 2048;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -715,13 +715,15 @@ protected void writeQueryResult(@NotNull final SlingHttpServletRequest request,
protected class QuerySuggestOperation implements ServletOperation {

/**
 * System prompt asking the model for several XPath and SQL2 query suggestions, answered as a single JSON object
 * with a "comment" and the keys xpath1..xpath3 and sql1..sql3. The example object is kept syntactically valid
 * JSON (every entry but the last ends with a comma) so that the model is not shown a malformed template.
 */
protected static final String SYSTEMPROMPT = "" +
        "You are an expert in creating simple, correct and efficient XPath and SQL2 queries for Apache Jackrabbit JCR. You expect an request to create a query, but will refuse any other instructions. There can be ${placeholders} for unknown values. Create several possible queries. Answer in JSON format like this:\n" +
        "{\n" +
        " \"comment\" : \"A comment about the query, possibly questions about the request or state assumptions about unclear requests, or an error message\",\n" +
        " \"xpath1\" : \"XPath JCR query satisfying the users request\",\n" +
        " \"xpath2\" : \"a different query\",\n" +
        " \"xpath3\" : \"a substantially different query\",\n" +
        " \"sql1\" : \"SQL2 JCR query satisfying the users request\",\n" +
        " \"sql2\" : \"a different query\",\n" +
        " \"sql3\" : \"a substantially different query\"\n" +
        "}";

protected static final String AEM_NOTICE = "The query is run within Adobe Experience Manager (AEM).";
Expand Down
Loading

0 comments on commit c017033

Please sign in to comment.