Skip to content

Commit

Permalink
RaaG : Ability to forward text and images to OpenAI API
Browse files Browse the repository at this point in the history
  • Loading branch information
amengus87 committed Jun 12, 2024
1 parent 41624a9 commit b271a32
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 7 deletions.
17 changes: 12 additions & 5 deletions backend/src/main/java/ai/dragon/service/ChatMessageService.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.springframework.stereotype.Service;

import ai.dragon.dto.openai.completion.OpenAiCompletionMessage;
import ai.dragon.util.DataUrlUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
Expand All @@ -34,11 +35,11 @@ public String singleTextFrom(UserMessage message) {

@SuppressWarnings("unchecked")
public Optional<ChatMessage> convertToChatMessage(OpenAiCompletionMessage completionMessage) {
ChatMessage chatMessage;
ChatMessage chatMessage = null;
switch (completionMessage.getRole()) {
case "user":
if (completionMessage.getContent() instanceof String) {
// TODO name
// TODO "UserMessage.name"
chatMessage = new UserMessage((String) completionMessage.getContent());
} else {
List<Map<String, Object>> content = (List<Map<String, Object>>) completionMessage.getContent();
Expand All @@ -55,11 +56,17 @@ public Optional<ChatMessage> convertToChatMessage(OpenAiCompletionMessage comple
} else if ("image_url".equals(type)) {
Map<String, Object> imageURL = (Map<String, Object>) contentItem.get("image_url");
String url = (String) imageURL.get("url");
if (url.startsWith("http")) {
contents.add(new ImageContent(url));
} else if (url.startsWith("data:")) {
String mimetype = DataUrlUtil.getImageType(url);
String base64String = DataUrlUtil.getDataBytesString(url);
contents.add(ImageContent.from(base64String, mimetype));
}
// TODO String detail = (String) imageURL.get("detail");
contents.add(new ImageContent(url));
}
});
// TODO name
// TODO "UserMessage.name"
chatMessage = new UserMessage(contents);
}
break;
Expand All @@ -70,7 +77,7 @@ public Optional<ChatMessage> convertToChatMessage(OpenAiCompletionMessage comple
chatMessage = new AiMessage((String) completionMessage.getContent());
break;
default:
throw new IllegalArgumentException("Invalid Message Role: " + completionMessage.getRole());
logger.error(String.format("Invalid Message Role '%s'", completionMessage.getRole()));
}
return Optional.ofNullable(chatMessage);
}
Expand Down
3 changes: 1 addition & 2 deletions backend/src/main/java/ai/dragon/service/RaagService.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ public SseEmitter chatResponse(FarmEntity farm, OpenAiChatCompletionRequest requ

private MessageWindowChatMemory buildChatMemory(OpenAiChatCompletionRequest request) {
MessageWindowChatMemory memory = MessageWindowChatMemory.withMaxMessages(10); // TODO maxMessages
// Don't retrieve the last message, it's not for history but for completion!
for (int i = 0; i < request.getMessages().size() - 1; i++) {
for (int i = 0; i < request.getMessages().size(); i++) {
OpenAiCompletionMessage requestMessage = request.getMessages().get(i);
chatMessageService.convertToChatMessage(requestMessage).ifPresent(memory::add);
}
Expand Down
75 changes: 75 additions & 0 deletions backend/src/main/java/ai/dragon/util/DataUrlUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package ai.dragon.util;

import java.util.Base64;

public class DataUrlUtil {
public static String getMimeType(String dataUrl) {
String[] parts = dataUrl.split(";");
if (parts.length < 2) {
return "octet/stream";
}
String mimeType = parts[0].split(":")[1];
return mimeType;
}

public static String getDataBytesString(String dataUrl) {
return dataUrl.substring(dataUrl.indexOf(",") + 1);
}

public static String getImageType(String base64String) {
return getImageType(getDataBytes(getDataBytesString(base64String)));
}

private static byte[] getDataBytes(String base64String) {
return Base64.getDecoder().decode(base64String);
}

// Forked from :
// https://nsclass.github.io/2017/03/13/java-an-example-code-to-extract-mime-type-from-base64-string-of-an-image/
private static String getImageType(byte[] data) {
// filetype magic number(hex)
// jpg FF D8 FF
// gif 47 49 46 38
// png 89 50 4E 47 0D 0A 1A 0A
// bmp 42 4D
// tiff(LE) 49 49 2A 00
// tiff(BE) 4D 4D 00 2A
final byte[] pngPattern = new byte[] { (byte) 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };
final byte[] jpgPattern = new byte[] { (byte) 0xFF, (byte) 0xD8, (byte) 0xFF };
final byte[] gifPattern = new byte[] { 0x47, 0x49, 0x46, 0x38 };
final byte[] bmpPattern = new byte[] { 0x42, 0x4D };
final byte[] tiffLEPattern = new byte[] { 0x49, 0x49, 0x2A, 0x00 };
final byte[] tiffBEPattern = new byte[] { 0x4D, 0x4D, 0x00, 0x2A };
if (isMatch(pngPattern, data)) {
return "image/png";
}
if (isMatch(jpgPattern, data)) {
return "image/jpg";
}
if (isMatch(gifPattern, data)) {
return "image/gif";
}
if (isMatch(bmpPattern, data)) {
return "image/bmp";
}
if (isMatch(tiffLEPattern, data)) {
return "image/tif";
}
if (isMatch(tiffBEPattern, data)) {
return "image/tif";
}
return "image/png";
}

private static boolean isMatch(byte[] pattern, byte[] data) {
if (pattern.length <= data.length) {
for (int idx = 0; idx < pattern.length; ++idx) {
if (pattern[idx] != data[idx]) {
return false;
}
}
return true;
}
return false;
}
}

0 comments on commit b271a32

Please sign in to comment.