Skip to content

Commit

Permalink
Enhance code clarity and functionality for chat message processing (#210
Browse files Browse the repository at this point in the history
)

* feat: Add EnhancedCompressingQueryTransformer class

This commit adds a new class called EnhancedCompressingQueryTransformer to the codebase. This class extends the CompressingQueryTransformer class and overrides the format method to provide a custom formatting for ChatMessage objects. The format method now checks if the message is a UserMessage or an AiMessage and formats it accordingly.

* Fix SC issues

* Fix SC issues

* Improvements by @coderabbitai

* Sc improvements

* Using SonarQube Gradle Plugin

* chore: Remove SonarQube Gradle Plugin from build.gradle

* ChatMessageUtil : Improvements from @coderabbitai
  • Loading branch information
amengus87 authored Aug 4, 2024
1 parent 61e6918 commit 5939170
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class ErrorHandlerController implements ErrorController {
// We are using a SPA (Single Page Application) so we need to output the
// index.html file when an error occurs :
@RequestMapping("/error")
public @ResponseBody byte[] getImage() throws IOException {
public @ResponseBody byte[] getSpaContent() throws IOException {
InputStream in = getClass().getResourceAsStream("/static/index.html");
if (in == null) {
throw new IOException("index.html not found");
Expand Down
119 changes: 0 additions & 119 deletions backend/src/main/java/ai/dragon/service/ChatMessageService.java

This file was deleted.

24 changes: 12 additions & 12 deletions backend/src/main/java/ai/dragon/service/RaagService.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
import ai.dragon.properties.raag.RetrievalAugmentorSettings;
import ai.dragon.repository.FarmRepository;
import ai.dragon.repository.SiloRepository;
import ai.dragon.util.ChatMessageUtil;
import ai.dragon.util.KVSettingUtil;
import ai.dragon.util.ai.AiAssistant;
import ai.dragon.util.spel.MetadataHeaderFilterExpressionParserUtil;
import ai.dragon.util.transformer.EnhancedCompressingQueryTransformer;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
Expand Down Expand Up @@ -65,9 +67,6 @@ public class RaagService {
@Autowired
private SiloRepository siloRepository;

@Autowired
private ChatMessageService chatMessageService;

@Autowired
private OpenAiCompletionService openAiCompletionService;

Expand Down Expand Up @@ -147,14 +146,14 @@ private OpenAiCompletionResponse completionResponse(FarmEntity farm, OpenAiCompl
HttpServletRequest servletRequest)
throws Exception {
AiAssistant assistant = this.makeCompletionAssistant(farm, completionRequest, servletRequest, false);
Result<String> answer = assistant.answer(chatMessageService.singleTextFrom(completionRequest));
Result<String> answer = assistant.answer(ChatMessageUtil.singleTextFrom(completionRequest));
return openAiCompletionService.createCompletionResponse(completionRequest, answer);
}

private SseEmitter streamCompletionResponse(FarmEntity farm, OpenAiCompletionRequest completionRequest,
HttpServletRequest servletRequest) throws Exception {
AiAssistant assistant = this.makeCompletionAssistant(farm, completionRequest, servletRequest, true);
TokenStream stream = assistant.chat(chatMessageService.singleTextFrom(completionRequest));
TokenStream stream = assistant.chat(ChatMessageUtil.singleTextFrom(completionRequest));

Check warning on line 156 in backend/src/main/java/ai/dragon/service/RaagService.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/service/RaagService.java#L156

Added line #L156 was not covered by tests
UUID emitterID = sseService.createEmitter();
stream
.onNext(nextChunk -> {
Expand All @@ -181,9 +180,9 @@ private OpenAiChatCompletionResponse chatCompletionResponse(FarmEntity farm,
AiAssistant assistant = this.makeChatAssistant(farm, chatCompletionRequest, servletRequest, false);
OpenAiCompletionMessage lastCompletionMessage = chatCompletionRequest.getMessages()
.get(chatCompletionRequest.getMessages().size() - 1);
ChatMessage lastChatMessage = chatMessageService.convertToChatMessage(lastCompletionMessage)
ChatMessage lastChatMessage = ChatMessageUtil.convertToChatMessage(lastCompletionMessage)
.orElseThrow();
Result<String> answer = assistant.answer(chatMessageService.singleTextFrom(lastChatMessage));
Result<String> answer = assistant.answer(ChatMessageUtil.singleTextFrom(lastChatMessage));
return openAiCompletionService.createChatCompletionResponse(answer);
}

Expand All @@ -193,9 +192,9 @@ private SseEmitter streamChatCompletionResponse(FarmEntity farm, OpenAiChatCompl
AiAssistant assistant = this.makeChatAssistant(farm, chatCompletionRequest, servletRequest, true);
OpenAiCompletionMessage lastCompletionMessage = chatCompletionRequest.getMessages()
.get(chatCompletionRequest.getMessages().size() - 1);
ChatMessage lastChatMessage = chatMessageService.convertToChatMessage(lastCompletionMessage)
ChatMessage lastChatMessage = ChatMessageUtil.convertToChatMessage(lastCompletionMessage)

Check warning on line 195 in backend/src/main/java/ai/dragon/service/RaagService.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/service/RaagService.java#L195

Added line #L195 was not covered by tests
.orElseThrow();
TokenStream stream = assistant.chat(chatMessageService.singleTextFrom(lastChatMessage));
TokenStream stream = assistant.chat(ChatMessageUtil.singleTextFrom(lastChatMessage));

Check warning on line 197 in backend/src/main/java/ai/dragon/service/RaagService.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/service/RaagService.java#L197

Added line #L197 was not covered by tests
UUID emitterID = sseService.createEmitter();
stream
.onNext(nextChunk -> {
Expand Down Expand Up @@ -260,7 +259,7 @@ private ChatMemory buildChatMemory(FarmEntity farm, OpenAiChatCompletionRequest
.apply(retrievalSettings);
for (int i = 0; i < request.getMessages().size(); i++) {
OpenAiCompletionMessage requestMessage = request.getMessages().get(i);
chatMessageService.convertToChatMessage(requestMessage).ifPresent(memory::add);
ChatMessageUtil.convertToChatMessage(requestMessage).ifPresent(memory::add);
}
return memory;
}
Expand Down Expand Up @@ -309,8 +308,9 @@ private void buildRetrievalAugmentor(AiServices<AiAssistant> assistantBuilder, F
&& openAiRequest instanceof OpenAiChatCompletionRequest) {
// Query Rewriting => Improve RAG Performance and Accuracy
// => Uses Chat History.
retrievalAugmentorBuilder.queryTransformer(CompressingQueryTransformer.builder()
.chatLanguageModel(this.buildChatLanguageModel(farm, openAiRequest)).build());
CompressingQueryTransformer compressingQueryTransformer = new EnhancedCompressingQueryTransformer(
this.buildChatLanguageModel(farm, openAiRequest));
retrievalAugmentorBuilder.queryTransformer(compressingQueryTransformer);
}
assistantBuilder.retrievalAugmentor(retrievalAugmentorBuilder.build());
}
Expand Down
115 changes: 115 additions & 0 deletions backend/src/main/java/ai/dragon/util/ChatMessageUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package ai.dragon.util;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.dragon.dto.openai.completion.OpenAiCompletionMessage;
import ai.dragon.dto.openai.completion.OpenAiCompletionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.UserMessage;

public class ChatMessageUtil {
private static final Logger LOGGER = LoggerFactory.getLogger(ChatMessageUtil.class);

private ChatMessageUtil() {
}

public static String singleTextFrom(ChatMessage message) {
StringBuilder sb = new StringBuilder();
if (message instanceof UserMessage userMessage) {
userMessage.contents().forEach(content -> {
if (content instanceof TextContent textContent) {
sb.append(textContent.text());
}
});
} else if (message instanceof AiMessage aiMessage) {
sb.append(aiMessage.text());

Check warning on line 37 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L37

Added line #L37 was not covered by tests
} else if (message instanceof SystemMessage systemMessage) {
sb.append(systemMessage.text());

Check warning on line 39 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L39

Added line #L39 was not covered by tests
}
return sb.toString();
}

public static String singleTextFrom(OpenAiCompletionRequest request) {
Object prompt = request.getPrompt();
if (prompt instanceof String stringPrompt) {
return stringPrompt;
} else if (prompt instanceof List<?> listPrompt) {
return listPrompt.stream()
.map(Object::toString)
.collect(Collectors.joining());

Check warning on line 51 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L49-L51

Added lines #L49 - L51 were not covered by tests
}
return null;

Check warning on line 53 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L53

Added line #L53 was not covered by tests
}

@SuppressWarnings("unchecked")
public static Optional<ChatMessage> convertToChatMessage(OpenAiCompletionMessage completionMessage) {
ChatMessage chatMessage = switch (completionMessage.getRole()) {
case "user" -> {
if (completionMessage.getContent() instanceof String stringContent) {
yield new UserMessage(stringContent);
} else {
List<Content> contents = contentsListFrom(
(List<Map<String, Object>>) completionMessage.getContent());
yield new UserMessage(contents);
}
}
case "system" -> new SystemMessage((String) completionMessage.getContent());
case "assistant" -> new AiMessage((String) completionMessage.getContent());

Check warning on line 69 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L69

Added line #L69 was not covered by tests
default -> {
LOGGER.error(String.format("Invalid Message Role '%s'", completionMessage.getRole()));

Check warning on line 71 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L71

Added line #L71 was not covered by tests
yield null;
}
};
return Optional.ofNullable(chatMessage);
}

private static List<Content> contentsListFrom(List<Map<String, Object>> content) {
List<Content> contents = new ArrayList<>();
content.forEach(contentItem -> {
String type = (String) contentItem.get("type");
if (type == null) {
LOGGER.error("Content part must have a type field!");
return;

Check warning on line 84 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L83-L84

Added lines #L83 - L84 were not covered by tests
}
contents.add(createContent(type, contentItem));
});
return contents;
}

@SuppressWarnings("unchecked")
private static Content createContent(String type, Map<String, Object> contentItem) {
switch (type) {
case "text":
return new TextContent((String) contentItem.get("text"));
case "image_url":
return createImageContent((Map<String, Object>) contentItem.get("image_url"));
default:
return null;

Check warning on line 99 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L99

Added line #L99 was not covered by tests
}
}

private static Content createImageContent(Map<String, Object> imageURL) {
String url = (String) imageURL.get("url");
if (url.startsWith("http")) {
return new ImageContent(url);

Check warning on line 106 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L106

Added line #L106 was not covered by tests
} else if (url.startsWith("data:")) {
String mimetype = DataUrlUtil.getImageType(url);
String base64String = DataUrlUtil.getDataBytesString(url);
return ImageContent.from(base64String, mimetype);
}
// TODO ImageURL.detail
return null;

Check warning on line 113 in backend/src/main/java/ai/dragon/util/ChatMessageUtil.java

View check run for this annotation

Codecov / codecov/patch

backend/src/main/java/ai/dragon/util/ChatMessageUtil.java#L113

Added line #L113 was not covered by tests
}
}
22 changes: 22 additions & 0 deletions backend/src/main/java/ai/dragon/util/DataUrlUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package ai.dragon.util;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Base64;

public class DataUrlUtil {
Expand All @@ -20,6 +24,24 @@ public static String getImageType(String base64String) {
return getImageType(getDataBytes(getDataBytesString(base64String)));
}

public static String convertFileToDataImageBase64(File file) throws IOException {
String mimeType = Files.probeContentType(file.toPath());
return convertFileToDataImageBase64(file, mimeType);
}

public static String convertFileToDataImageBase64(File file, String mimeType) throws IOException {
// Read the file into a byte array
byte[] fileBytes;
try (FileInputStream fileInputStream = new FileInputStream(file)) {
fileBytes = new byte[(int) file.length()];
fileInputStream.read(fileBytes);
}
// Encode the byte array to base64
String base64Encoded = Base64.getEncoder().encodeToString(fileBytes);
// Construct the data:image base64 string
return "data:" + mimeType + ";base64," + base64Encoded;
}

private static byte[] getDataBytes(String base64String) {
return Base64.getDecoder().decode(base64String);
}
Expand Down
Loading

0 comments on commit 5939170

Please sign in to comment.