diff --git a/backend/checkstyle.xml b/backend/checkstyle.xml index ed08fdc0..ab4dd7b7 100644 --- a/backend/checkstyle.xml +++ b/backend/checkstyle.xml @@ -32,7 +32,11 @@ - + + + + + diff --git a/backend/src/main/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiController.java b/backend/src/main/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiController.java index cd26752a..9d65a005 100644 --- a/backend/src/main/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiController.java +++ b/backend/src/main/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiController.java @@ -5,11 +5,13 @@ import java.util.UUID; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpStatus; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.server.ResponseStatusException; import ai.dragon.dto.openai.completion.OpenAiChatCompletionChoice; import ai.dragon.dto.openai.completion.OpenAiChatCompletionRequest; @@ -19,8 +21,10 @@ import ai.dragon.dto.openai.completion.OpenAiCompletionRequest; import ai.dragon.dto.openai.completion.OpenAiCompletionResponse; import ai.dragon.dto.openai.completion.OpenAiCompletionUsage; -import ai.dragon.dto.openai.model.OpenAiModel; import ai.dragon.dto.openai.model.OpenAiModelsReponse; +import ai.dragon.entity.FarmEntity; +import ai.dragon.repository.FarmRepository; +import ai.dragon.service.RaagService; import ai.dragon.service.SseService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.media.Content; @@ -36,16 +40,17 @@ public class OpenAiCompatibleV1ApiController { @Autowired private SseService sseService; + @Autowired + private FarmRepository farmRepository; + + @Autowired + private RaagService raagService; + @GetMapping("/models") @Operation(summary = "List models", description = "Lists available models.") public OpenAiModelsReponse models() { return OpenAiModelsReponse.builder() - .data(List.of(OpenAiModel - .builder() - .created(System.currentTimeMillis() / 1000) - .id("dragon-ppx") - .owned_by("dRAGon") - .build())) + .data(raagService.listAvailableModels()) .build(); } @@ -105,63 +110,35 @@ public Object completions(@Valid @RequestBody OpenAiCompletionRequest request) @Operation(summary = "Creates a chat completion", description = "Creates a chat completion for the provided prompt and parameters.") @ApiResponse(responseCode = "200", content = @Content(schema = @Schema(implementation = OpenAiChatCompletionResponse.class))) public Object chatCompletions(@Valid @RequestBody OpenAiChatCompletionRequest request) throws Exception { + FarmEntity farm = farmRepository + .findUniqueByFieldValue("raagIdentifier", request.getModel()) + .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Farm not found")); if (Boolean.TRUE.equals(request.getStream())) { - UUID emitterID = sseService.createEmitter(); - for (int i = 0; i < 3; i++) { - OpenAiChatCompletionResponse responseChunk = new OpenAiChatCompletionResponse(); - responseChunk.setId(emitterID.toString()); - responseChunk.setModel(request.getModel()); - responseChunk.setCreated(System.currentTimeMillis() / 1000); - responseChunk.setObject("chat.completion.chunk"); - List choices = new ArrayList<>(); - choices.add(OpenAiChatCompletionChoice - .builder() - .index(0) - .finish_reason(i == 2 ? "stop" : null) - .delta(OpenAiCompletionMessage - .builder() - .role("assistant") - .content("Chunk : " + i + "\r\n") - .build()) - .build()); - responseChunk.setChoices(choices); - sseService.sendEvent(emitterID, responseChunk); - } - sseService.sendEvent(emitterID, "[DONE]"); - new Thread(() -> { - try { - Thread.sleep(500); - } catch (InterruptedException e) { - e.printStackTrace(); - } - sseService.complete(emitterID); - }).start(); - return sseService.retrieveEmitter(emitterID); + return raagService.chatResponse(farm, request); } else { - OpenAiChatCompletionResponse response = new OpenAiChatCompletionResponse(); - response.setId(UUID.randomUUID().toString()); - response.setModel(request.getModel()); - response.setCreated(System.currentTimeMillis() / 1000); - response.setObject("chat.completion"); - response.setUsage(OpenAiCompletionUsage - .builder() - .completion_tokens(0) - .prompt_tokens(0) - .total_tokens(0) - .build()); - List choices = new ArrayList<>(); - choices.add(OpenAiChatCompletionChoice + return OpenAiChatCompletionResponse .builder() - .index(0) - .finish_reason("stop") - .message(OpenAiCompletionMessage + .id(UUID.randomUUID().toString()) + .model(request.getModel()) + .created(System.currentTimeMillis() / 1000) + .object("chat.completion") + .usage(OpenAiCompletionUsage .builder() - .role("assistant") - .content("Hello, how can I help you today?") + .completion_tokens(0) + .prompt_tokens(0) + .total_tokens(0) .build()) - .build()); - response.setChoices(choices); - return response; + .choices(List.of(OpenAiChatCompletionChoice + .builder() + .index(0) + .finish_reason("stop") + .message(OpenAiCompletionMessage + .builder() + .role("assistant") + .content("Hello, how can I help you today?") + .build()) + .build())) + .build(); } } } diff --git a/backend/src/main/java/ai/dragon/dto/llm/StreamingChatLanguageModelDefinition.java b/backend/src/main/java/ai/dragon/dto/llm/StreamingChatLanguageModelDefinition.java new file mode 100644 index 00000000..7027e87d --- /dev/null +++ b/backend/src/main/java/ai/dragon/dto/llm/StreamingChatLanguageModelDefinition.java @@ -0,0 +1,16 @@ +package ai.dragon.dto.llm; + +import java.util.function.Function; + +import ai.dragon.enumeration.ProviderType; +import ai.dragon.properties.embedding.LanguageModelSettings; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +public class StreamingChatLanguageModelDefinition { + private Function modelWithSettings; + private ProviderType providerType; +} diff --git a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionRequest.java b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionRequest.java index 27a33450..67ba308c 100644 --- a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionRequest.java +++ b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionRequest.java @@ -2,6 +2,7 @@ import java.util.List; +import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotNull; @@ -11,6 +12,7 @@ public class OpenAiChatCompletionRequest { @NotBlank @NotNull + @Schema(description = "Name of the Farm 'Raag Model' to be used.") private String model; @NotEmpty diff --git a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionResponse.java b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionResponse.java index 2fe41736..9c0cffaa 100644 --- a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionResponse.java +++ b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionResponse.java @@ -2,9 +2,11 @@ import java.util.List; +import lombok.Builder; import lombok.Data; @Data +@Builder public class OpenAiChatCompletionResponse { private String id; private String object; diff --git a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiCompletionRequest.java b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiCompletionRequest.java index 777aa35b..5fd17354 100644 --- a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiCompletionRequest.java +++ b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiCompletionRequest.java @@ -1,7 +1,7 @@ package ai.dragon.dto.openai.completion; import jakarta.validation.constraints.NotNull; - +import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.NotBlank; import lombok.Data; @@ -9,6 +9,7 @@ public class OpenAiCompletionRequest { @NotBlank @NotNull + @Schema(description = "Name of the Farm 'Raag Model' to be used.") private String model; @NotNull diff --git a/backend/src/main/java/ai/dragon/entity/FarmEntity.java b/backend/src/main/java/ai/dragon/entity/FarmEntity.java index 8350604a..085d5266 100644 --- a/backend/src/main/java/ai/dragon/entity/FarmEntity.java +++ b/backend/src/main/java/ai/dragon/entity/FarmEntity.java @@ -10,14 +10,19 @@ import org.dizitart.no2.repository.annotations.Index; import org.dizitart.no2.repository.annotations.Indices; +import ai.dragon.enumeration.LanguageModelType; import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Pattern; import lombok.Getter; import lombok.Setter; @Entity(value = "farm") @Schema(name = "Farm", description = "Farm Entity") -@Indices({ @Index(fields = "name", type = IndexType.UNIQUE) }) +@Indices({ + @Index(fields = "name", type = IndexType.UNIQUE), + @Index(fields = "raagIdentifier", type = IndexType.UNIQUE) +}) @Getter @Setter public class FarmEntity implements AbstractEntity { @@ -30,14 +35,26 @@ public class FarmEntity implements AbstractEntity { @Schema(description = "Name of the Farm. Must be unique.") private String name; + @Schema(description = "Identifier for the 'Raag Model' to be used for the RaaG API. Must be unique") + @Pattern(regexp = "^[a-z0-9\\-]+$", message = "Must be alphanumeric, hyphens allowed") + private String raagIdentifier; + @Schema(description = """ List of Silo UUIDs to be linked to the Farm. - A farm is a collection of Silos, each Silo is a collection of Documents.""") + A Farm is a collection of Silos, which each Silo is a collection of Documents.""") private List silos; + @Schema(description = "Language Model to be used for the RaaG API") + private LanguageModelType languageModel; + + @Schema(description = "Settings to be linked to the Farm's Language Model in the form of `key = value` pairs.") + private List languageModelSettings; + public FarmEntity() { this.uuid = UUID.randomUUID(); this.name = String.format("Farm %s", this.uuid.toString()); this.silos = new ArrayList(); + this.raagIdentifier = UUID.randomUUID().toString(); + this.languageModel = LanguageModelType.OpenAiModel; } } diff --git a/backend/src/main/java/ai/dragon/enumeration/EmbeddingModelType.java b/backend/src/main/java/ai/dragon/enumeration/EmbeddingModelType.java index 68c6145a..58b3148b 100644 --- a/backend/src/main/java/ai/dragon/enumeration/EmbeddingModelType.java +++ b/backend/src/main/java/ai/dragon/enumeration/EmbeddingModelType.java @@ -28,6 +28,11 @@ public static EmbeddingModelType fromString(String text) { return null; } + @Override + public String toString() { + return value; + } + public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundException { switch (this) { case BgeSmallEnV15QuantizedEmbeddingModel: @@ -93,9 +98,4 @@ public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundExcepti throw new ClassNotFoundException("Model not found"); } } - - @Override - public String toString() { - return value; - } } diff --git a/backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java b/backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java new file mode 100644 index 00000000..98f4d1b0 --- /dev/null +++ b/backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java @@ -0,0 +1,51 @@ +package ai.dragon.enumeration; + +import java.time.Duration; + +import ai.dragon.dto.llm.StreamingChatLanguageModelDefinition; +import ai.dragon.service.SseService; +import dev.langchain4j.model.openai.OpenAiStreamingChatModel; + +public enum LanguageModelType { + OpenAiModel("OpenAiModel"); + + private String value; + + LanguageModelType(String value) { + this.value = value; + } + + public static LanguageModelType fromString(String text) { + for (LanguageModelType b : LanguageModelType.values()) { + if (b.value.equalsIgnoreCase(text)) { + return b; + } + } + return null; + } + + @Override + public String toString() { + return value; + } + + public StreamingChatLanguageModelDefinition getStreamingChatLanguageModel() throws ClassNotFoundException { + switch (this) { + case OpenAiModel: + return StreamingChatLanguageModelDefinition + .builder() + .modelWithSettings(parameters -> { + return OpenAiStreamingChatModel + .builder() + .apiKey(parameters.getApiKey()) + .modelName(parameters.getModelName()) + .timeout(Duration.ofSeconds(SseService.DEFAULT_TIMEOUT)) + .build(); + }) + .providerType(ProviderType.OpenAI) + .build(); + default: + throw new ClassNotFoundException("Model not found"); + } + } +} diff --git a/backend/src/main/java/ai/dragon/job/silo/ingestor/loader/filesystem/FileSystemIngestorLoader.java b/backend/src/main/java/ai/dragon/job/silo/ingestor/loader/filesystem/FileSystemIngestorLoader.java index 63f9a21e..30a6e6a1 100644 --- a/backend/src/main/java/ai/dragon/job/silo/ingestor/loader/filesystem/FileSystemIngestorLoader.java +++ b/backend/src/main/java/ai/dragon/job/silo/ingestor/loader/filesystem/FileSystemIngestorLoader.java @@ -77,7 +77,7 @@ public void checkIngestorLoaderSettings() throws Exception { String[] paths = loaderSettings.getPath().trim().split(","); pathsToIngest.clear(); for (String path : paths) { - File pathFile = new File(path); + File pathFile = new File(path.trim()); if (!pathFile.exists() || !pathFile.isDirectory()) { logger.warn("Skipping directory because not found : {}", pathFile); continue; diff --git a/backend/src/main/java/ai/dragon/properties/embedding/LanguageModelSettings.java b/backend/src/main/java/ai/dragon/properties/embedding/LanguageModelSettings.java new file mode 100644 index 00000000..aa41ea33 --- /dev/null +++ b/backend/src/main/java/ai/dragon/properties/embedding/LanguageModelSettings.java @@ -0,0 +1,12 @@ +package ai.dragon.properties.embedding; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; + +import lombok.Data; + +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class LanguageModelSettings { + private String apiKey; + private String modelName; +} diff --git a/backend/src/main/java/ai/dragon/repository/AbstractRepository.java b/backend/src/main/java/ai/dragon/repository/AbstractRepository.java index 164636b6..67476391 100644 --- a/backend/src/main/java/ai/dragon/repository/AbstractRepository.java +++ b/backend/src/main/java/ai/dragon/repository/AbstractRepository.java @@ -87,6 +87,14 @@ public Cursor findByFieldValue(String fieldName, Object fieldValue) { return this.findWithFilter(FluentFilter.where(fieldName).eq(fieldValue)); } + public Optional findUniqueByFieldValue(String fieldName, Object fieldValue) { + Cursor cursor = this.findByFieldValue(fieldName, fieldValue); + if (cursor.size() > 1) { + throw new ResponseStatusException(HttpStatus.CONFLICT, "Multiple entities found"); + } + return cursor.size() == 1 ? Optional.of(cursor.firstOrNull()) : Optional.empty(); + } + public void delete(String uuid) { delete(UUID.fromString(uuid)); } diff --git a/backend/src/main/java/ai/dragon/service/ChatMessageService.java b/backend/src/main/java/ai/dragon/service/ChatMessageService.java new file mode 100644 index 00000000..776167b2 --- /dev/null +++ b/backend/src/main/java/ai/dragon/service/ChatMessageService.java @@ -0,0 +1,77 @@ +package ai.dragon.service; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; + +import ai.dragon.dto.openai.completion.OpenAiCompletionMessage; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.Content; +import dev.langchain4j.data.message.ImageContent; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.TextContent; +import dev.langchain4j.data.message.UserMessage; + +@Service +public class ChatMessageService { + private final Logger logger = LoggerFactory.getLogger(this.getClass()); + + public String singleTextFrom(UserMessage message) { + StringBuilder sb = new StringBuilder(); + message.contents().forEach(content -> { + if (content instanceof TextContent) { + sb.append(((TextContent) content).text()); + } + }); + return sb.toString(); + } + + @SuppressWarnings("unchecked") + public Optional convertToChatMessage(OpenAiCompletionMessage completionMessage) { + ChatMessage chatMessage; + switch (completionMessage.getRole()) { + case "user": + if (completionMessage.getContent() instanceof String) { + // TODO name + chatMessage = new UserMessage((String) completionMessage.getContent()); + } else { + List> content = (List>) completionMessage.getContent(); + List contents = new ArrayList<>(); + content.forEach(contentItem -> { + if (!contentItem.containsKey("type")) { + logger.error("Content part must have a type field!"); + return; + } + String type = (String) contentItem.get("type"); + if ("text".equals(type)) { + String text = (String) contentItem.get("text"); + contents.add(new TextContent(text)); + } else if ("image_url".equals(type)) { + Map imageURL = (Map) contentItem.get("image_url"); + String url = (String) imageURL.get("url"); + // TODO String detail = (String) imageURL.get("detail"); + contents.add(new ImageContent(url)); + } + }); + // TODO name + chatMessage = new UserMessage(contents); + } + break; + case "system": + chatMessage = new SystemMessage((String) completionMessage.getContent()); + break; + case "assistant": + chatMessage = new AiMessage((String) completionMessage.getContent()); + break; + default: + throw new IllegalArgumentException("Invalid Message Role: " + completionMessage.getRole()); + } + return Optional.ofNullable(chatMessage); + } +} diff --git a/backend/src/main/java/ai/dragon/service/EmbeddingModelService.java b/backend/src/main/java/ai/dragon/service/EmbeddingModelService.java index 38aa43f6..6376f128 100644 --- a/backend/src/main/java/ai/dragon/service/EmbeddingModelService.java +++ b/backend/src/main/java/ai/dragon/service/EmbeddingModelService.java @@ -1,10 +1,15 @@ package ai.dragon.service; +import java.util.UUID; + import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; +import org.springframework.web.server.ResponseStatusException; import ai.dragon.entity.SiloEntity; import ai.dragon.properties.embedding.EmbeddingSettings; +import ai.dragon.repository.SiloRepository; import dev.langchain4j.model.embedding.EmbeddingModel; @Service @@ -12,7 +17,16 @@ public class EmbeddingModelService { @Autowired private KVSettingService kvSettingService; - public EmbeddingModel modelForEntity(SiloEntity siloEntity) throws Exception { + @Autowired + private SiloRepository siloRepository; + + public EmbeddingModel modelForSilo(UUID siloUuid) throws Exception { + SiloEntity siloEntity = siloRepository.getByUuid(siloUuid) + .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Silo not found")); + return modelForSilo(siloEntity); + } + + public EmbeddingModel modelForSilo(SiloEntity siloEntity) throws Exception { EmbeddingSettings embeddingSettings = kvSettingService.kvSettingsToObject( siloEntity.getEmbeddingSettings(), EmbeddingSettings.class); return siloEntity.getEmbeddingModel().getModelDefinition().getEmbeddingModelWithSettings() diff --git a/backend/src/main/java/ai/dragon/service/EmbeddingStoreService.java b/backend/src/main/java/ai/dragon/service/EmbeddingStoreService.java index 153edf2c..240f115e 100644 --- a/backend/src/main/java/ai/dragon/service/EmbeddingStoreService.java +++ b/backend/src/main/java/ai/dragon/service/EmbeddingStoreService.java @@ -96,7 +96,7 @@ public void clearEmbeddingStore(UUID siloUuid) throws Exception { public EmbeddingSearchResult query(UUID siloUuid, String query, Integer maxResults) throws Exception { SiloEntity siloEntity = siloRepository.getByUuid(siloUuid).orElseThrow(); EmbeddingStore embeddingStore = retrieveEmbeddingStore(siloUuid); - EmbeddingModel embeddingModel = embeddingModelService.modelForEntity(siloEntity); + EmbeddingModel embeddingModel = embeddingModelService.modelForSilo(siloEntity); Embedding queryEmbedding = embeddingModel.embed(query).content(); // Filter onlyForUser1 = metadataKey("userId").isEqualTo("1"); EmbeddingSearchRequest embeddingSearchRequest1 = EmbeddingSearchRequest.builder() diff --git a/backend/src/main/java/ai/dragon/service/IngestorService.java b/backend/src/main/java/ai/dragon/service/IngestorService.java index 276bba51..c798e3c9 100644 --- a/backend/src/main/java/ai/dragon/service/IngestorService.java +++ b/backend/src/main/java/ai/dragon/service/IngestorService.java @@ -74,7 +74,7 @@ private void ingestDocumentsToSilo(List documents, SiloEntity siloEnti throws Exception { EmbeddingStore embeddingStore = embeddingStoreService .retrieveEmbeddingStore(siloEntity.getUuid()); - EmbeddingModel embeddingModel = embeddingModelService.modelForEntity(siloEntity); + EmbeddingModel embeddingModel = embeddingModelService.modelForSilo(siloEntity); EmbeddingStoreIngestor ingestor = buildIngestor(embeddingStore, embeddingModel, siloEntity); logCallback.accept(SiloIngestLoaderLogMessage.builder() .message(String.format( diff --git a/backend/src/main/java/ai/dragon/service/RaagService.java b/backend/src/main/java/ai/dragon/service/RaagService.java new file mode 100644 index 00000000..89ac9a83 --- /dev/null +++ b/backend/src/main/java/ai/dragon/service/RaagService.java @@ -0,0 +1,187 @@ +package ai.dragon.service; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import ai.dragon.dto.openai.completion.OpenAiChatCompletionChoice; +import ai.dragon.dto.openai.completion.OpenAiChatCompletionRequest; +import ai.dragon.dto.openai.completion.OpenAiChatCompletionResponse; +import ai.dragon.dto.openai.completion.OpenAiCompletionMessage; +import ai.dragon.dto.openai.model.OpenAiModel; +import ai.dragon.entity.FarmEntity; +import ai.dragon.properties.embedding.LanguageModelSettings; +import ai.dragon.repository.FarmRepository; +import ai.dragon.util.ai.AiAssistant; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; +import dev.langchain4j.rag.query.router.DefaultQueryRouter; +import dev.langchain4j.service.AiServices; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.store.embedding.EmbeddingStore; + +@Service +public class RaagService { + private final Logger logger = LoggerFactory.getLogger(this.getClass()); + + @Autowired + private EmbeddingStoreService embeddingStoreService; + + @Autowired + private EmbeddingModelService embeddingModelService; + + @Autowired + private SseService sseService; + + @Autowired + private FarmRepository farmRepository; + + @Autowired + private KVSettingService kvSettingService; + + @Autowired + private ChatMessageService chatMessageService; + + public List listAvailableModels() { + return farmRepository + .find() + .toList() + .stream() + .map(farm -> { + return OpenAiModel + .builder() + .created(System.currentTimeMillis() / 1000) + .id(farm.getRaagIdentifier()) + .owned_by("dRAGon RaaG") + .build(); + }) + .toList(); + } + + public SseEmitter chatResponse(FarmEntity farm, OpenAiChatCompletionRequest request) throws Exception { + AiAssistant assistant = AiServices.builder(AiAssistant.class) + .streamingChatLanguageModel(this.buildStreamingChatLanguageModel(farm)) + // TODO support of chatLanguageModel in addition of streamingChatLanguageModel + .retrievalAugmentor(this.buildRetrievalAugmentor(farm)) + .chatMemory(this.buildChatMemory(request)) + .build(); + OpenAiCompletionMessage lastCompletionMessage = request.getMessages().get(request.getMessages().size() - 1); + UserMessage lastChatMessage = (UserMessage) chatMessageService.convertToChatMessage(lastCompletionMessage) + .orElseThrow(); + TokenStream stream = assistant.chat(chatMessageService.singleTextFrom(lastChatMessage)); + UUID emitterID = sseService.createEmitter(); + stream + .onNext(nextChunk -> { + sseService.sendEvent(emitterID, + this.createChatCompletionResponse(emitterID, request, nextChunk, false)); + }) + .onComplete(response -> { + sseService.sendEvent(emitterID, + this.createChatCompletionResponse(emitterID, request, "", true)); + sseService.sendEvent(emitterID, "[DONE]"); + sseService.complete(emitterID); + }) + .onError(Throwable::printStackTrace) + .start(); + return sseService.retrieveEmitter(emitterID); + } + + private MessageWindowChatMemory buildChatMemory(OpenAiChatCompletionRequest request) { + MessageWindowChatMemory memory = MessageWindowChatMemory.withMaxMessages(10); // TODO maxMessages + // Don't retrieve the last message, it's not for history but for completion! + for (int i = 0; i < request.getMessages().size() - 1; i++) { + OpenAiCompletionMessage requestMessage = request.getMessages().get(i); + chatMessageService.convertToChatMessage(requestMessage).ifPresent(memory::add); + } + return memory; + } + + private OpenAiChatCompletionResponse createChatCompletionResponse( + UUID emitterID, + OpenAiChatCompletionRequest request, + String nextChunk, + boolean isLastChunk) { + return OpenAiChatCompletionResponse + .builder() + .id(emitterID.toString()) + .model(request.getModel()) + .created(System.currentTimeMillis() / 1000) + .object("chat.completion.chunk") + .choices(List.of(OpenAiChatCompletionChoice + .builder() + .index(0) + .finish_reason(isLastChunk ? "stop" : null) + .delta(OpenAiCompletionMessage + .builder() + .role("assistant") + .content(nextChunk) + .build()) + .build())) + .build(); + } + + private StreamingChatLanguageModel buildStreamingChatLanguageModel(FarmEntity farm) throws Exception { + return farm + .getLanguageModel() + .getStreamingChatLanguageModel() + .getModelWithSettings() + .apply(kvSettingService + .kvSettingsToObject(farm.getLanguageModelSettings(), + LanguageModelSettings.class)); + } + + private RetrievalAugmentor buildRetrievalAugmentor(FarmEntity farm) { + return DefaultRetrievalAugmentor.builder() + .queryRouter(new DefaultQueryRouter(this.buildRetrieverList(farm))) + .build(); + } + + public List buildRetrieverList(FarmEntity farm) { + List retrievers = new ArrayList<>(); + if (farm.getSilos() == null || farm.getSilos().isEmpty()) { + logger.warn("No Silos found for Farm '{}' (RaaG Identifier '{}'), no content retrieve will be made", + farm.getUuid(), farm.getRaagIdentifier()); + return retrievers; + } + farm.getSilos().forEach(siloUuid -> { + try { + this.buildRetriever(siloUuid).ifPresent(retrievers::add); + } catch (Exception ex) { + logger.error("Error building Content Retriever for Silo '{}'", siloUuid, ex); + } + }); + return retrievers; + } + + public Optional buildRetriever(UUID siloUuid) throws Exception { + EmbeddingModel embeddingModel = embeddingModelService.modelForSilo(siloUuid); + EmbeddingStore embeddingStore = embeddingStoreService.retrieveEmbeddingStore(siloUuid); + return Optional.of(EmbeddingStoreContentRetriever.builder() + .embeddingStore(embeddingStore) + .embeddingModel(embeddingModel) + .dynamicMaxResults(query -> { + return 10; // TODO SiloEntity or FarmEntity settings + }) + .dynamicMinScore(query -> { + return 0.8; // TODO SiloEntity or FarmEntity settings + }) + .dynamicFilter(query -> { + return null; // TODO SiloEntity or FarmEntity settings + }) + .build()); + } +} diff --git a/backend/src/main/java/ai/dragon/service/SseService.java b/backend/src/main/java/ai/dragon/service/SseService.java index 3e4f65f3..499a6162 100644 --- a/backend/src/main/java/ai/dragon/service/SseService.java +++ b/backend/src/main/java/ai/dragon/service/SseService.java @@ -11,12 +11,15 @@ @Service public class SseService { + // TODO Custom timeout : + public static final long DEFAULT_TIMEOUT = 90L * 1000; + private final Logger logger = LoggerFactory.getLogger(this.getClass()); private ConcurrentHashMap emitters = new ConcurrentHashMap<>(); public UUID createEmitter() { - return createEmitter(10L * 1000L); + return createEmitter(DEFAULT_TIMEOUT); } public UUID createEmitter(Long timeout) { @@ -36,27 +39,31 @@ public SseEmitter retrieveEmitter(UUID id) { return emitters.get(id); } - public void complete(UUID id) { + public boolean complete(UUID id) { SseEmitter emitter = emitters.get(id); if (emitter == null) { - logger.warn("No emitter found for id: {}", id); - return; + logger.warn("Can't complete : No emitter found for id '{}'", id); + return false; } emitter.complete(); emitters.remove(id); + return true; } - public void sendEvent(UUID id, Object event) { + public boolean sendEvent(UUID id, Object event) { SseEmitter emitter = emitters.get(id); if (emitter == null) { - logger.warn("No emitter found for id: {}", id); - return; + logger.info(event.toString()); + logger.warn("Can't send event : No emitter found for id '{}'", id); + return false; } try { emitter.send(event); } catch (IOException e) { emitter.complete(); emitters.remove(id); + return false; } + return true; } } diff --git a/backend/src/main/java/ai/dragon/util/ai/AiAssistant.java b/backend/src/main/java/ai/dragon/util/ai/AiAssistant.java new file mode 100644 index 00000000..4aa84a0e --- /dev/null +++ b/backend/src/main/java/ai/dragon/util/ai/AiAssistant.java @@ -0,0 +1,9 @@ +package ai.dragon.util.ai; + +import dev.langchain4j.service.Result; +import dev.langchain4j.service.TokenStream; + +public interface AiAssistant { + Result answer(String query); + TokenStream chat(String message); +} diff --git a/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java b/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java index 73794bc8..e5406a3d 100644 --- a/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java +++ b/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java @@ -1,10 +1,12 @@ package ai.dragon.controller.api.raag; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import java.util.List; import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; import org.springframework.boot.test.web.server.LocalServerPort; @@ -13,17 +15,28 @@ import com.theokanning.openai.model.Model; import com.theokanning.openai.service.OpenAiService; +import ai.dragon.entity.FarmEntity; +import ai.dragon.repository.FarmRepository; + @SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) @ActiveProfiles("test") public class OpenAiCompatibleV1ApiControllerTest { @LocalServerPort private int serverPort; + @Autowired + private FarmRepository farmRepository; + @Test void listModels() throws Exception { + farmRepository.deleteAll(); + FarmEntity farm = new FarmEntity(); + farm.setRaagIdentifier("awesome-raag"); + farmRepository.save(farm); OpenAiService service = new OpenAiService("TODO_PUT_KEY_HERE", String.format("http://localhost:%d/api/raag/v1/", serverPort)); List models = service.listModels(); assertFalse(models.isEmpty()); + assertEquals(1, models.size()); } } diff --git a/docs/docs/integrations/librechat.mdx b/docs/docs/integrations/librechat.mdx index 082ff6d0..e7b4c593 100644 --- a/docs/docs/integrations/librechat.mdx +++ b/docs/docs/integrations/librechat.mdx @@ -8,11 +8,11 @@ endpoints: - name: "dRAGon" apiKey: "YOUR_API_KEY" baseURL: "http://localhost:1985/api/raag/v1" - models: - default: ["dragon-silo-uuid", "dragon-farm-uuid"] + models: + default: ["your-raag-model-name"] fetch: true titleConvo: true - titleModel: "dragon-farm-uuid" + titleModel: "your-raag-model-name" titleMethod: "completion" modelDisplayLabel: "dRAGon" iconURL: "https://dragon.okinawa/img/dragon_okinawa_icon.png"