Skip to content

Commit

Permalink
Merge pull request #216 from dRAGon-Okinawa/staging
Browse files Browse the repository at this point in the history
Enhance code clarity and functionality for chat message processing (#…
  • Loading branch information
isontheline authored Aug 7, 2024
2 parents e7cbd5e + a6cfa4f commit 3405c84
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class ErrorHandlerController implements ErrorController {
// We are using a SPA (Single Page Application) so we need to output the
// index.html file when an error occurs :
@RequestMapping("/error")
public @ResponseBody byte[] getImage() throws IOException {
public @ResponseBody byte[] getSpaContent() throws IOException {
InputStream in = getClass().getResourceAsStream("/static/index.html");
if (in == null) {
throw new IOException("index.html not found");
Expand Down
119 changes: 0 additions & 119 deletions backend/src/main/java/ai/dragon/service/ChatMessageService.java

This file was deleted.

24 changes: 12 additions & 12 deletions backend/src/main/java/ai/dragon/service/RaagService.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
import ai.dragon.properties.raag.RetrievalAugmentorSettings;
import ai.dragon.repository.FarmRepository;
import ai.dragon.repository.SiloRepository;
import ai.dragon.util.ChatMessageUtil;
import ai.dragon.util.KVSettingUtil;
import ai.dragon.util.ai.AiAssistant;
import ai.dragon.util.spel.MetadataHeaderFilterExpressionParserUtil;
import ai.dragon.util.transformer.EnhancedCompressingQueryTransformer;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
Expand Down Expand Up @@ -65,9 +67,6 @@ public class RaagService {
@Autowired
private SiloRepository siloRepository;

@Autowired
private ChatMessageService chatMessageService;

@Autowired
private OpenAiCompletionService openAiCompletionService;

Expand Down Expand Up @@ -147,14 +146,14 @@ private OpenAiCompletionResponse completionResponse(FarmEntity farm, OpenAiCompl
HttpServletRequest servletRequest)
throws Exception {
AiAssistant assistant = this.makeCompletionAssistant(farm, completionRequest, servletRequest, false);
Result<String> answer = assistant.answer(chatMessageService.singleTextFrom(completionRequest));
Result<String> answer = assistant.answer(ChatMessageUtil.singleTextFrom(completionRequest));
return openAiCompletionService.createCompletionResponse(completionRequest, answer);
}

private SseEmitter streamCompletionResponse(FarmEntity farm, OpenAiCompletionRequest completionRequest,
HttpServletRequest servletRequest) throws Exception {
AiAssistant assistant = this.makeCompletionAssistant(farm, completionRequest, servletRequest, true);
TokenStream stream = assistant.chat(chatMessageService.singleTextFrom(completionRequest));
TokenStream stream = assistant.chat(ChatMessageUtil.singleTextFrom(completionRequest));
UUID emitterID = sseService.createEmitter();
stream
.onNext(nextChunk -> {
Expand All @@ -181,9 +180,9 @@ private OpenAiChatCompletionResponse chatCompletionResponse(FarmEntity farm,
AiAssistant assistant = this.makeChatAssistant(farm, chatCompletionRequest, servletRequest, false);
OpenAiCompletionMessage lastCompletionMessage = chatCompletionRequest.getMessages()
.get(chatCompletionRequest.getMessages().size() - 1);
ChatMessage lastChatMessage = chatMessageService.convertToChatMessage(lastCompletionMessage)
ChatMessage lastChatMessage = ChatMessageUtil.convertToChatMessage(lastCompletionMessage)
.orElseThrow();
Result<String> answer = assistant.answer(chatMessageService.singleTextFrom(lastChatMessage));
Result<String> answer = assistant.answer(ChatMessageUtil.singleTextFrom(lastChatMessage));
return openAiCompletionService.createChatCompletionResponse(answer);
}

Expand All @@ -193,9 +192,9 @@ private SseEmitter streamChatCompletionResponse(FarmEntity farm, OpenAiChatCompl
AiAssistant assistant = this.makeChatAssistant(farm, chatCompletionRequest, servletRequest, true);
OpenAiCompletionMessage lastCompletionMessage = chatCompletionRequest.getMessages()
.get(chatCompletionRequest.getMessages().size() - 1);
ChatMessage lastChatMessage = chatMessageService.convertToChatMessage(lastCompletionMessage)
ChatMessage lastChatMessage = ChatMessageUtil.convertToChatMessage(lastCompletionMessage)
.orElseThrow();
TokenStream stream = assistant.chat(chatMessageService.singleTextFrom(lastChatMessage));
TokenStream stream = assistant.chat(ChatMessageUtil.singleTextFrom(lastChatMessage));
UUID emitterID = sseService.createEmitter();
stream
.onNext(nextChunk -> {
Expand Down Expand Up @@ -260,7 +259,7 @@ private ChatMemory buildChatMemory(FarmEntity farm, OpenAiChatCompletionRequest
.apply(retrievalSettings);
for (int i = 0; i < request.getMessages().size(); i++) {
OpenAiCompletionMessage requestMessage = request.getMessages().get(i);
chatMessageService.convertToChatMessage(requestMessage).ifPresent(memory::add);
ChatMessageUtil.convertToChatMessage(requestMessage).ifPresent(memory::add);
}
return memory;
}
Expand Down Expand Up @@ -309,8 +308,9 @@ private void buildRetrievalAugmentor(AiServices<AiAssistant> assistantBuilder, F
&& openAiRequest instanceof OpenAiChatCompletionRequest) {
// Query Rewriting => Improve RAG Performance and Accuracy
// => Uses Chat History.
retrievalAugmentorBuilder.queryTransformer(CompressingQueryTransformer.builder()
.chatLanguageModel(this.buildChatLanguageModel(farm, openAiRequest)).build());
CompressingQueryTransformer compressingQueryTransformer = new EnhancedCompressingQueryTransformer(
this.buildChatLanguageModel(farm, openAiRequest));
retrievalAugmentorBuilder.queryTransformer(compressingQueryTransformer);
}
assistantBuilder.retrievalAugmentor(retrievalAugmentorBuilder.build());
}
Expand Down
117 changes: 117 additions & 0 deletions backend/src/main/java/ai/dragon/util/ChatMessageUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package ai.dragon.util;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.dragon.dto.openai.completion.OpenAiCompletionMessage;
import ai.dragon.dto.openai.completion.OpenAiCompletionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.UserMessage;

public class ChatMessageUtil {
private static final Logger LOGGER = LoggerFactory.getLogger(ChatMessageUtil.class);

private ChatMessageUtil() {
}

public static String singleTextFrom(ChatMessage message) {
StringBuilder sb = new StringBuilder();
if (message instanceof UserMessage userMessage) {
userMessage.contents().forEach(content -> {
if (content instanceof TextContent textContent) {
sb.append(textContent.text());
}
});
} else if (message instanceof AiMessage aiMessage) {
sb.append(aiMessage.text());
} else if (message instanceof SystemMessage systemMessage) {
sb.append(systemMessage.text());
}
return sb.toString();
}

public static String singleTextFrom(OpenAiCompletionRequest request) {
Object prompt = request.getPrompt();
if (prompt instanceof String stringPrompt) {
return stringPrompt;
} else if (prompt instanceof List<?> listPrompt) {
return listPrompt.stream()
.map(Object::toString)
.collect(Collectors.joining());
}
return null;
}

@SuppressWarnings("unchecked")
public static Optional<ChatMessage> convertToChatMessage(OpenAiCompletionMessage completionMessage) {
ChatMessage chatMessage = switch (completionMessage.getRole()) {
case "user" -> {
if (completionMessage.getContent() instanceof String stringContent) {
yield new UserMessage(stringContent);
} else {
List<Content> contents = contentsListFrom(
(List<Map<String, Object>>) completionMessage.getContent());
yield new UserMessage(contents);
}
}
case "system" -> new SystemMessage((String) completionMessage.getContent());
case "assistant" -> new AiMessage((String) completionMessage.getContent());
default -> {
LOGGER.error(String.format("Invalid Message Role '%s'", completionMessage.getRole()));
yield null;
}
};
return Optional.ofNullable(chatMessage);
}

private static List<Content> contentsListFrom(List<Map<String, Object>> content) {
List<Content> contents = new ArrayList<>();
if (content != null) {
content.forEach(contentItem -> {
String type = (String) contentItem.get("type");
if (type == null) {
LOGGER.error("Content part must have a type field!");
return;
}
contents.add(createContent(type, contentItem));
});
}
return contents;
}

@SuppressWarnings("unchecked")
private static Content createContent(String type, Map<String, Object> contentItem) {
switch (type) {
case "text":
return new TextContent((String) contentItem.get("text"));
case "image_url":
return createImageContent((Map<String, Object>) contentItem.get("image_url"));
default:
return null;
}
}

private static Content createImageContent(Map<String, Object> imageURL) {
String url = (String) imageURL.get("url");
if (url.startsWith("http")) {
return new ImageContent(url);
} else if (url.startsWith("data:")) {
String mimetype = DataUrlUtil.getImageType(url);
String base64String = DataUrlUtil.getDataBytesString(url);
return ImageContent.from(base64String, mimetype);
}
// TODO ImageURL.detail
return null;
}
}
22 changes: 22 additions & 0 deletions backend/src/main/java/ai/dragon/util/DataUrlUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package ai.dragon.util;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Base64;

public class DataUrlUtil {
Expand All @@ -20,6 +24,24 @@ public static String getImageType(String base64String) {
return getImageType(getDataBytes(getDataBytesString(base64String)));
}

public static String convertFileToDataImageBase64(File file) throws IOException {
String mimeType = Files.probeContentType(file.toPath());
return convertFileToDataImageBase64(file, mimeType != null ? mimeType : "application/octet-stream");
}

public static String convertFileToDataImageBase64(File file, String mimeType) throws IOException {
// Read the file into a byte array
byte[] fileBytes;
try (FileInputStream fileInputStream = new FileInputStream(file)) {
fileBytes = new byte[(int) file.length()];
fileInputStream.read(fileBytes);
}
// Encode the byte array to base64
String base64Encoded = Base64.getEncoder().encodeToString(fileBytes);
// Construct the data:image base64 string
return "data:" + mimeType + ";base64," + base64Encoded;
}

private static byte[] getDataBytes(String base64String) {
return Base64.getDecoder().decode(base64String);
}
Expand Down
Loading

0 comments on commit 3405c84

Please sign in to comment.