diff --git a/backend/checkstyle.xml b/backend/checkstyle.xml
index ed08fdc0..ab4dd7b7 100644
--- a/backend/checkstyle.xml
+++ b/backend/checkstyle.xml
@@ -32,7 +32,11 @@
-
+
+
+
+
+
diff --git a/backend/src/main/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiController.java b/backend/src/main/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiController.java
index cd26752a..9d65a005 100644
--- a/backend/src/main/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiController.java
+++ b/backend/src/main/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiController.java
@@ -5,11 +5,13 @@
import java.util.UUID;
import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.server.ResponseStatusException;
import ai.dragon.dto.openai.completion.OpenAiChatCompletionChoice;
import ai.dragon.dto.openai.completion.OpenAiChatCompletionRequest;
@@ -19,8 +21,10 @@
import ai.dragon.dto.openai.completion.OpenAiCompletionRequest;
import ai.dragon.dto.openai.completion.OpenAiCompletionResponse;
import ai.dragon.dto.openai.completion.OpenAiCompletionUsage;
-import ai.dragon.dto.openai.model.OpenAiModel;
import ai.dragon.dto.openai.model.OpenAiModelsReponse;
+import ai.dragon.entity.FarmEntity;
+import ai.dragon.repository.FarmRepository;
+import ai.dragon.service.RaagService;
import ai.dragon.service.SseService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.Content;
@@ -36,16 +40,17 @@ public class OpenAiCompatibleV1ApiController {
@Autowired
private SseService sseService;
+ @Autowired
+ private FarmRepository farmRepository;
+
+ @Autowired
+ private RaagService raagService;
+
@GetMapping("/models")
@Operation(summary = "List models", description = "Lists available models.")
public OpenAiModelsReponse models() {
return OpenAiModelsReponse.builder()
- .data(List.of(OpenAiModel
- .builder()
- .created(System.currentTimeMillis() / 1000)
- .id("dragon-ppx")
- .owned_by("dRAGon")
- .build()))
+ .data(raagService.listAvailableModels())
.build();
}
@@ -105,63 +110,35 @@ public Object completions(@Valid @RequestBody OpenAiCompletionRequest request)
@Operation(summary = "Creates a chat completion", description = "Creates a chat completion for the provided prompt and parameters.")
@ApiResponse(responseCode = "200", content = @Content(schema = @Schema(implementation = OpenAiChatCompletionResponse.class)))
public Object chatCompletions(@Valid @RequestBody OpenAiChatCompletionRequest request) throws Exception {
+ FarmEntity farm = farmRepository
+ .findUniqueByFieldValue("raagIdentifier", request.getModel())
+ .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Farm not found"));
if (Boolean.TRUE.equals(request.getStream())) {
- UUID emitterID = sseService.createEmitter();
- for (int i = 0; i < 3; i++) {
- OpenAiChatCompletionResponse responseChunk = new OpenAiChatCompletionResponse();
- responseChunk.setId(emitterID.toString());
- responseChunk.setModel(request.getModel());
- responseChunk.setCreated(System.currentTimeMillis() / 1000);
- responseChunk.setObject("chat.completion.chunk");
- List choices = new ArrayList<>();
- choices.add(OpenAiChatCompletionChoice
- .builder()
- .index(0)
- .finish_reason(i == 2 ? "stop" : null)
- .delta(OpenAiCompletionMessage
- .builder()
- .role("assistant")
- .content("Chunk : " + i + "\r\n")
- .build())
- .build());
- responseChunk.setChoices(choices);
- sseService.sendEvent(emitterID, responseChunk);
- }
- sseService.sendEvent(emitterID, "[DONE]");
- new Thread(() -> {
- try {
- Thread.sleep(500);
- } catch (InterruptedException e) {
- e.printStackTrace();
- }
- sseService.complete(emitterID);
- }).start();
- return sseService.retrieveEmitter(emitterID);
+ return raagService.chatResponse(farm, request);
} else {
- OpenAiChatCompletionResponse response = new OpenAiChatCompletionResponse();
- response.setId(UUID.randomUUID().toString());
- response.setModel(request.getModel());
- response.setCreated(System.currentTimeMillis() / 1000);
- response.setObject("chat.completion");
- response.setUsage(OpenAiCompletionUsage
- .builder()
- .completion_tokens(0)
- .prompt_tokens(0)
- .total_tokens(0)
- .build());
- List choices = new ArrayList<>();
- choices.add(OpenAiChatCompletionChoice
+ return OpenAiChatCompletionResponse
.builder()
- .index(0)
- .finish_reason("stop")
- .message(OpenAiCompletionMessage
+ .id(UUID.randomUUID().toString())
+ .model(request.getModel())
+ .created(System.currentTimeMillis() / 1000)
+ .object("chat.completion")
+ .usage(OpenAiCompletionUsage
.builder()
- .role("assistant")
- .content("Hello, how can I help you today?")
+ .completion_tokens(0)
+ .prompt_tokens(0)
+ .total_tokens(0)
.build())
- .build());
- response.setChoices(choices);
- return response;
+ .choices(List.of(OpenAiChatCompletionChoice
+ .builder()
+ .index(0)
+ .finish_reason("stop")
+ .message(OpenAiCompletionMessage
+ .builder()
+ .role("assistant")
+ .content("Hello, how can I help you today?")
+ .build())
+ .build()))
+ .build();
}
}
}
diff --git a/backend/src/main/java/ai/dragon/dto/llm/StreamingChatLanguageModelDefinition.java b/backend/src/main/java/ai/dragon/dto/llm/StreamingChatLanguageModelDefinition.java
new file mode 100644
index 00000000..7027e87d
--- /dev/null
+++ b/backend/src/main/java/ai/dragon/dto/llm/StreamingChatLanguageModelDefinition.java
@@ -0,0 +1,22 @@
+package ai.dragon.dto.llm;
+
+import java.util.function.Function;
+
+import ai.dragon.enumeration.ProviderType;
+import ai.dragon.properties.embedding.LanguageModelSettings;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import lombok.Builder;
+import lombok.Data;
+
+/**
+ * Describes how to obtain a streaming chat language model and which provider it belongs to.
+ */
+@Data
+@Builder
+public class StreamingChatLanguageModelDefinition {
+    /** Factory that builds the streaming chat model from the supplied language model settings. */
+    private Function<LanguageModelSettings, StreamingChatLanguageModel> modelWithSettings;
+
+    /** Provider associated with this definition. */
+    private ProviderType providerType;
+}
diff --git a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionRequest.java b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionRequest.java
index 27a33450..67ba308c 100644
--- a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionRequest.java
+++ b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionRequest.java
@@ -2,6 +2,7 @@
import java.util.List;
+import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
@@ -11,6 +12,7 @@
public class OpenAiChatCompletionRequest {
@NotBlank
@NotNull
+ @Schema(description = "Name of the Farm 'Raag Model' to be used.")
private String model;
@NotEmpty
diff --git a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionResponse.java b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionResponse.java
index 2fe41736..9c0cffaa 100644
--- a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionResponse.java
+++ b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiChatCompletionResponse.java
@@ -2,9 +2,11 @@
import java.util.List;
+import lombok.Builder;
import lombok.Data;
@Data
+@Builder
public class OpenAiChatCompletionResponse {
private String id;
private String object;
diff --git a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiCompletionRequest.java b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiCompletionRequest.java
index 777aa35b..5fd17354 100644
--- a/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiCompletionRequest.java
+++ b/backend/src/main/java/ai/dragon/dto/openai/completion/OpenAiCompletionRequest.java
@@ -1,7 +1,7 @@
package ai.dragon.dto.openai.completion;
import jakarta.validation.constraints.NotNull;
-
+import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;
@@ -9,6 +9,7 @@
public class OpenAiCompletionRequest {
@NotBlank
@NotNull
+ @Schema(description = "Name of the Farm 'Raag Model' to be used.")
private String model;
@NotNull
diff --git a/backend/src/main/java/ai/dragon/entity/FarmEntity.java b/backend/src/main/java/ai/dragon/entity/FarmEntity.java
index 8350604a..085d5266 100644
--- a/backend/src/main/java/ai/dragon/entity/FarmEntity.java
+++ b/backend/src/main/java/ai/dragon/entity/FarmEntity.java
@@ -10,14 +10,19 @@
import org.dizitart.no2.repository.annotations.Index;
import org.dizitart.no2.repository.annotations.Indices;
+import ai.dragon.enumeration.LanguageModelType;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Pattern;
import lombok.Getter;
import lombok.Setter;
@Entity(value = "farm")
@Schema(name = "Farm", description = "Farm Entity")
-@Indices({ @Index(fields = "name", type = IndexType.UNIQUE) })
+@Indices({
+ @Index(fields = "name", type = IndexType.UNIQUE),
+ @Index(fields = "raagIdentifier", type = IndexType.UNIQUE)
+})
@Getter
@Setter
public class FarmEntity implements AbstractEntity {
@@ -30,14 +35,26 @@ public class FarmEntity implements AbstractEntity {
@Schema(description = "Name of the Farm. Must be unique.")
private String name;
+    @Schema(description = "Identifier for the 'Raag Model' to be used for the RaaG API. Must be unique")
+    @Pattern(regexp = "^[a-z0-9\\-]+$", message = "Must be lowercase alphanumeric, hyphens allowed")
+    private String raagIdentifier;
+
@Schema(description = """
List of Silo UUIDs to be linked to the Farm.
- A farm is a collection of Silos, each Silo is a collection of Documents.""")
+ A Farm is a collection of Silos, which each Silo is a collection of Documents.""")
private List silos;
+ @Schema(description = "Language Model to be used for the RaaG API")
+ private LanguageModelType languageModel;
+
+ @Schema(description = "Settings to be linked to the Farm's Language Model in the form of `key = value` pairs.")
+ private List languageModelSettings;
+
public FarmEntity() {
this.uuid = UUID.randomUUID();
this.name = String.format("Farm %s", this.uuid.toString());
this.silos = new ArrayList();
+ this.raagIdentifier = UUID.randomUUID().toString();
+ this.languageModel = LanguageModelType.OpenAiModel;
}
}
diff --git a/backend/src/main/java/ai/dragon/enumeration/EmbeddingModelType.java b/backend/src/main/java/ai/dragon/enumeration/EmbeddingModelType.java
index 68c6145a..58b3148b 100644
--- a/backend/src/main/java/ai/dragon/enumeration/EmbeddingModelType.java
+++ b/backend/src/main/java/ai/dragon/enumeration/EmbeddingModelType.java
@@ -28,6 +28,11 @@ public static EmbeddingModelType fromString(String text) {
return null;
}
+ @Override
+ public String toString() {
+ return value;
+ }
+
public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundException {
switch (this) {
case BgeSmallEnV15QuantizedEmbeddingModel:
@@ -93,9 +98,4 @@ public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundExcepti
throw new ClassNotFoundException("Model not found");
}
}
-
- @Override
- public String toString() {
- return value;
- }
}
diff --git a/backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java b/backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java
new file mode 100644
index 00000000..98f4d1b0
--- /dev/null
+++ b/backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java
@@ -0,0 +1,66 @@
+package ai.dragon.enumeration;
+
+import java.time.Duration;
+
+import ai.dragon.dto.llm.StreamingChatLanguageModelDefinition;
+import ai.dragon.service.SseService;
+import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
+
+/**
+ * Language model types, each able to provide a streaming chat model definition.
+ */
+public enum LanguageModelType {
+    OpenAiModel("OpenAiModel");
+
+    /** Textual form of the constant; returned by toString() and matched by fromString(). */
+    private final String value;
+
+    LanguageModelType(String value) {
+        this.value = value;
+    }
+
+    /**
+     * Looks up a constant by its textual value, ignoring case.
+     *
+     * @param text the value to look for; {@code null} never matches
+     * @return the matching constant, or {@code null} when none matches
+     */
+    public static LanguageModelType fromString(String text) {
+        for (LanguageModelType type : LanguageModelType.values()) {
+            if (type.value.equalsIgnoreCase(text)) {
+                return type;
+            }
+        }
+        return null;
+    }
+
+    @Override
+    public String toString() {
+        return value;
+    }
+
+    /**
+     * Returns the definition used to build the streaming chat model for this type.
+     *
+     * @return the definition holding the model factory and its provider type
+     * @throws ClassNotFoundException if this constant has no model definition
+     */
+    public StreamingChatLanguageModelDefinition getStreamingChatLanguageModel() throws ClassNotFoundException {
+        switch (this) {
+            case OpenAiModel:
+                return StreamingChatLanguageModelDefinition
+                        .builder()
+                        .modelWithSettings(parameters -> OpenAiStreamingChatModel
+                                .builder()
+                                .apiKey(parameters.getApiKey())
+                                .modelName(parameters.getModelName())
+                                // Reuses SseService.DEFAULT_TIMEOUT, interpreted here as a number of seconds.
+                                .timeout(Duration.ofSeconds(SseService.DEFAULT_TIMEOUT))
+                                .build())
+                        .providerType(ProviderType.OpenAI)
+                        .build();
+            default:
+                throw new ClassNotFoundException("Model not found");
+        }
+    }
+}
diff --git a/backend/src/main/java/ai/dragon/job/silo/ingestor/loader/filesystem/FileSystemIngestorLoader.java b/backend/src/main/java/ai/dragon/job/silo/ingestor/loader/filesystem/FileSystemIngestorLoader.java
index 63f9a21e..30a6e6a1 100644
--- a/backend/src/main/java/ai/dragon/job/silo/ingestor/loader/filesystem/FileSystemIngestorLoader.java
+++ b/backend/src/main/java/ai/dragon/job/silo/ingestor/loader/filesystem/FileSystemIngestorLoader.java
@@ -77,7 +77,7 @@ public void checkIngestorLoaderSettings() throws Exception {
String[] paths = loaderSettings.getPath().trim().split(",");
pathsToIngest.clear();
for (String path : paths) {
- File pathFile = new File(path);
+ File pathFile = new File(path.trim());
if (!pathFile.exists() || !pathFile.isDirectory()) {
logger.warn("Skipping directory because not found : {}", pathFile);
continue;
diff --git a/backend/src/main/java/ai/dragon/properties/embedding/LanguageModelSettings.java b/backend/src/main/java/ai/dragon/properties/embedding/LanguageModelSettings.java
new file mode 100644
index 00000000..aa41ea33
--- /dev/null
+++ b/backend/src/main/java/ai/dragon/properties/embedding/LanguageModelSettings.java
@@ -0,0 +1,19 @@
+package ai.dragon.properties.embedding;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+
+import lombok.Data;
+
+/**
+ * Settings used to configure a language model. Unknown JSON properties are ignored when binding.
+ */
+@Data
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class LanguageModelSettings {
+    /** Provider API key; kept out of the generated toString() so the secret cannot end up in logs. */
+    @lombok.ToString.Exclude
+    private String apiKey;
+
+    /** Name of the model to use. */
+    private String modelName;
+}
diff --git a/backend/src/main/java/ai/dragon/repository/AbstractRepository.java b/backend/src/main/java/ai/dragon/repository/AbstractRepository.java
index 164636b6..67476391 100644
--- a/backend/src/main/java/ai/dragon/repository/AbstractRepository.java
+++ b/backend/src/main/java/ai/dragon/repository/AbstractRepository.java
@@ -87,6 +87,14 @@ public Cursor findByFieldValue(String fieldName, Object fieldValue) {
return this.findWithFilter(FluentFilter.where(fieldName).eq(fieldValue));
}
+    /**
+     * Finds the single entity whose field equals the given value.
+     *
+     * @param fieldName name of the field to filter on
+     * @param fieldValue value the field must be equal to
+     * @return the matching entity, or an empty Optional when nothing matches
+     * @throws ResponseStatusException with 409 CONFLICT when more than one entity matches
+     */
+    public Optional findUniqueByFieldValue(String fieldName, Object fieldValue) {
+        Cursor cursor = this.findByFieldValue(fieldName, fieldValue);
+        // Count once and reuse the result instead of asking the cursor for its size twice.
+        long matches = cursor.size();
+        if (matches > 1) {
+            throw new ResponseStatusException(HttpStatus.CONFLICT, "Multiple entities found");
+        }
+        return matches == 1 ? Optional.of(cursor.firstOrNull()) : Optional.empty();
+    }
+
public void delete(String uuid) {
delete(UUID.fromString(uuid));
}
diff --git a/backend/src/main/java/ai/dragon/service/ChatMessageService.java b/backend/src/main/java/ai/dragon/service/ChatMessageService.java
new file mode 100644
index 00000000..776167b2
--- /dev/null
+++ b/backend/src/main/java/ai/dragon/service/ChatMessageService.java
@@ -0,0 +1,77 @@
+package ai.dragon.service;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Service;
+
+import ai.dragon.dto.openai.completion.OpenAiCompletionMessage;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.Content;
+import dev.langchain4j.data.message.ImageContent;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.TextContent;
+import dev.langchain4j.data.message.UserMessage;
+
+@Service
+public class ChatMessageService {
+ private final Logger logger = LoggerFactory.getLogger(this.getClass());
+
+    /** Concatenates, in order, the text of every TextContent part of the given message. */
+    public String singleTextFrom(UserMessage message) {
+        return message.contents().stream()
+                // Only text parts contribute; any other kind of content is skipped.
+                .filter(TextContent.class::isInstance)
+                .map(TextContent.class::cast)
+                .map(TextContent::text)
+                .collect(java.util.stream.Collectors.joining());
+    }
+
+ @SuppressWarnings("unchecked")
+ public Optional convertToChatMessage(OpenAiCompletionMessage completionMessage) {
+ ChatMessage chatMessage;
+ switch (completionMessage.getRole()) {
+ case "user":
+ if (completionMessage.getContent() instanceof String) {
+ // TODO name
+ chatMessage = new UserMessage((String) completionMessage.getContent());
+ } else {
+ List