Skip to content

Commit

Permalink
Begin Chat Memory
Browse files Browse the repository at this point in the history
  • Loading branch information
amengus87 committed Jun 11, 2024
1 parent cc53399 commit 575cdd0
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
62 changes: 62 additions & 0 deletions backend/src/main/java/ai/dragon/service/ChatMessageService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package ai.dragon.service;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

import ai.dragon.dto.openai.completion.OpenAiCompletionMessage;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.UserMessage;

@Service
public class ChatMessageService {
private final Logger logger = LoggerFactory.getLogger(this.getClass());

public Optional<ChatMessage> convertToChatMessage(OpenAiCompletionMessage completionMessage) {
ChatMessage chatMessage;
switch (completionMessage.getRole()) {
case "user":
if (completionMessage.getContent() instanceof String) {
chatMessage = new UserMessage(completionMessage.getName(), (String) completionMessage.getContent());
} else {
List<Map<String, Object>> content = (List<Map<String, Object>>) completionMessage.getContent();
List<Content> contents = new ArrayList<>();
content.forEach(contentItem -> {
if (!contentItem.containsKey("type")) {
logger.error("Content part must have a type field!");
return;
}
String type = (String) contentItem.get("type");
if ("text".equals(type)) {
String text = (String) contentItem.get("text");
contents.add(new TextContent(text));
} else if ("image_url".equals(type)) {
Map<String, Object> imageURL = (Map<String, Object>) contentItem.get("image_url");
String url = (String) imageURL.get("url");
// TODO String detail = (String) imageURL.get("detail");
contents.add(new ImageContent(url));
}
});
chatMessage = new UserMessage(completionMessage.getName(), contents);
}
break;
case "system":
chatMessage = new SystemMessage((String) completionMessage.getContent());
case "assistant":
chatMessage = new AiMessage((String) completionMessage.getContent());
default:
throw new IllegalArgumentException("Invalid Message Role: " + completionMessage.getRole());
}
return Optional.ofNullable(chatMessage);
}
}
16 changes: 15 additions & 1 deletion backend/src/main/java/ai/dragon/service/RaagService.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.dragon.repository.FarmRepository;
import ai.dragon.util.ai.AiAssistant;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
Expand Down Expand Up @@ -51,6 +52,9 @@ public class RaagService {
@Autowired
private KVSettingService kvSettingService;

@Autowired
private ChatMessageService chatMessageService;

public List<OpenAiModel> listAvailableModels() {
return farmRepository
.find()
Expand All @@ -71,8 +75,8 @@ public SseEmitter chatResponse(FarmEntity farm, OpenAiChatCompletionRequest requ
AiAssistant assistant = AiServices.builder(AiAssistant.class)
.streamingChatLanguageModel(this.buildStreamingChatLanguageModel(farm))
// TODO support of chatLanguageModel in addition of streamingChatLanguageModel
// TODO .chatMemory(MessageWindowChatMemory.withMaxMessages(10)) // it should
.retrievalAugmentor(this.buildRetrievalAugmentor(farm))
.chatMemory(this.buildChatMemory(request))
.build();
// TODO request.getMessages().get(0).getContent())
TokenStream stream = assistant.chat((String) request.getMessages().get(0).getContent());
Expand All @@ -93,6 +97,16 @@ public SseEmitter chatResponse(FarmEntity farm, OpenAiChatCompletionRequest requ
return sseService.retrieveEmitter(emitterID);
}

private MessageWindowChatMemory buildChatMemory(OpenAiChatCompletionRequest request) {
MessageWindowChatMemory memory = MessageWindowChatMemory.withMaxMessages(10); // TODO maxMessages
// Don't retrieve the last message, it's not for history but for completion!
for (int i = 0; i < request.getMessages().size() - 1; i++) {
OpenAiCompletionMessage requestMessage = request.getMessages().get(i);
chatMessageService.convertToChatMessage(requestMessage).ifPresent(memory::add);
}
return memory;
}

private OpenAiChatCompletionResponse createChatCompletionResponse(
UUID emitterID,
OpenAiChatCompletionRequest request,
Expand Down

0 comments on commit 575cdd0

Please sign in to comment.