From 575cdd092faf9d0521d2cc9f1b94264891d537d9 Mon Sep 17 00:00:00 2001
From: Arnaud Mengus
Date: Tue, 11 Jun 2024 15:26:23 +0000
Subject: [PATCH] Begin Chat Memory

---
 .../ai/dragon/service/ChatMessageService.java | 76 +++++++++++++++++++
 .../java/ai/dragon/service/RaagService.java   | 16 +++-
 2 files changed, 91 insertions(+), 1 deletion(-)
 create mode 100644 backend/src/main/java/ai/dragon/service/ChatMessageService.java

diff --git a/backend/src/main/java/ai/dragon/service/ChatMessageService.java b/backend/src/main/java/ai/dragon/service/ChatMessageService.java
new file mode 100644
index 00000000..c10a18ca
--- /dev/null
+++ b/backend/src/main/java/ai/dragon/service/ChatMessageService.java
@@ -0,0 +1,76 @@
+package ai.dragon.service;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Service;
+
+import ai.dragon.dto.openai.completion.OpenAiCompletionMessage;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.Content;
+import dev.langchain4j.data.message.ImageContent;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.TextContent;
+import dev.langchain4j.data.message.UserMessage;
+
+/**
+ * Converts OpenAI-style completion messages into langchain4j chat messages.
+ */
+@Service
+public class ChatMessageService {
+    private final Logger logger = LoggerFactory.getLogger(this.getClass());
+
+    /**
+     * Converts a completion message into the chat message type matching its role.
+     *
+     * @param completionMessage message to convert; its role selects the target type
+     * @return the converted message wrapped in an {@link Optional}
+     * @throws IllegalArgumentException if the role is not "user", "system" or "assistant"
+     */
+    public Optional<ChatMessage> convertToChatMessage(OpenAiCompletionMessage completionMessage) {
+        ChatMessage chatMessage;
+        switch (completionMessage.getRole()) {
+            case "user":
+                if (completionMessage.getContent() instanceof String) {
+                    chatMessage = new UserMessage(completionMessage.getName(), (String) completionMessage.getContent());
+                } else {
+                    List<Map<String, Object>> content = (List<Map<String, Object>>) completionMessage.getContent();
+                    List<Content> contents = new ArrayList<>();
+                    content.forEach(contentItem -> {
+                        if (!contentItem.containsKey("type")) {
+                            logger.error("Content part must have a type field!");
+                            return;
+                        }
+                        String type = (String) contentItem.get("type");
+                        if ("text".equals(type)) {
+                            String text = (String) contentItem.get("text");
+                            contents.add(new TextContent(text));
+                        } else if ("image_url".equals(type)) {
+                            Map<String, Object> imageURL = (Map<String, Object>) contentItem.get("image_url");
+                            String url = (String) imageURL.get("url");
+                            // TODO String detail = (String) imageURL.get("detail");
+                            contents.add(new ImageContent(url));
+                        }
+                    });
+                    chatMessage = new UserMessage(completionMessage.getName(), contents);
+                }
+                break;
+            case "system":
+                chatMessage = new SystemMessage((String) completionMessage.getContent());
+                // Stop here: falling through would reach "default" and throw.
+                break;
+            case "assistant":
+                chatMessage = new AiMessage((String) completionMessage.getContent());
+                // Stop here: falling through would reach "default" and throw.
+                break;
+            default:
+                throw new IllegalArgumentException("Invalid Message Role: " + completionMessage.getRole());
+        }
+        return Optional.ofNullable(chatMessage);
+    }
+}
diff --git a/backend/src/main/java/ai/dragon/service/RaagService.java b/backend/src/main/java/ai/dragon/service/RaagService.java
index e184478d..b88c6b90 100644
--- a/backend/src/main/java/ai/dragon/service/RaagService.java
+++ b/backend/src/main/java/ai/dragon/service/RaagService.java
@@ -21,6 +21,7 @@ import ai.dragon.repository.FarmRepository;
 import ai.dragon.util.ai.AiAssistant;
 
 import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.rag.DefaultRetrievalAugmentor;
@@ -51,6 +52,9 @@ public class RaagService {
     @Autowired
     private KVSettingService kvSettingService;
 
+    @Autowired
+    private ChatMessageService chatMessageService;
+
     public List listAvailableModels() {
         return farmRepository
                 .find()
@@ -71,8 +75,8 @@ public SseEmitter chatResponse(FarmEntity farm, OpenAiChatCompletionRequest requ
         AiAssistant assistant = AiServices.builder(AiAssistant.class)
                 .streamingChatLanguageModel(this.buildStreamingChatLanguageModel(farm))
                 // TODO support of chatLanguageModel in addition of streamingChatLanguageModel
-                // TODO .chatMemory(MessageWindowChatMemory.withMaxMessages(10)) // it should
                 .retrievalAugmentor(this.buildRetrievalAugmentor(farm))
+                .chatMemory(this.buildChatMemory(request))
                 .build();
         // TODO request.getMessages().get(0).getContent())
         TokenStream stream = assistant.chat((String) request.getMessages().get(0).getContent());
@@ -93,6 +97,16 @@ public SseEmitter chatResponse(FarmEntity farm, OpenAiChatCompletionRequest requ
         return sseService.retrieveEmitter(emitterID);
     }
 
+    private MessageWindowChatMemory buildChatMemory(OpenAiChatCompletionRequest request) {
+        MessageWindowChatMemory memory = MessageWindowChatMemory.withMaxMessages(10); // TODO maxMessages
+        // Don't retrieve the last message, it's not for history but for completion!
+        for (int i = 0; i < request.getMessages().size() - 1; i++) {
+            OpenAiCompletionMessage requestMessage = request.getMessages().get(i);
+            chatMessageService.convertToChatMessage(requestMessage).ifPresent(memory::add);
+        }
+        return memory;
+    }
+
     private OpenAiChatCompletionResponse createChatCompletionResponse(
             UUID emitterID,
             OpenAiChatCompletionRequest request,