Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Staging #52

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backend/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
<property name="caseIndent" value="4"/>
<property name="throwsIndent" value="4"/>
<property name="arrayInitIndent" value="4"/>
<property name="lineWrappingIndentation" value="4"/>
<property name="lineWrappingIndentation" value="8"/>
</module>
<module name="RegexpSinglelineJava">
<property name="format" value="sk\-proj\-[a-zA-Z0-9+]"/>
<property name="ignoreComments" value="false"/>
</module>
</module>
<module name="NewlineAtEndOfFile"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import java.util.UUID;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ResponseStatusException;

import ai.dragon.dto.openai.completion.OpenAiChatCompletionChoice;
import ai.dragon.dto.openai.completion.OpenAiChatCompletionRequest;
Expand All @@ -19,8 +21,10 @@
import ai.dragon.dto.openai.completion.OpenAiCompletionRequest;
import ai.dragon.dto.openai.completion.OpenAiCompletionResponse;
import ai.dragon.dto.openai.completion.OpenAiCompletionUsage;
import ai.dragon.dto.openai.model.OpenAiModel;
import ai.dragon.dto.openai.model.OpenAiModelsReponse;
import ai.dragon.entity.FarmEntity;
import ai.dragon.repository.FarmRepository;
import ai.dragon.service.RaagService;
import ai.dragon.service.SseService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.Content;
Expand All @@ -36,16 +40,17 @@ public class OpenAiCompatibleV1ApiController {
@Autowired
private SseService sseService;

@Autowired
private FarmRepository farmRepository;

@Autowired
private RaagService raagService;

@GetMapping("/models")
@Operation(summary = "List models", description = "Lists available models.")
public OpenAiModelsReponse models() {
return OpenAiModelsReponse.builder()
.data(List.of(OpenAiModel
.builder()
.created(System.currentTimeMillis() / 1000)
.id("dragon-ppx")
.owned_by("dRAGon")
.build()))
.data(raagService.listAvailableModels())
.build();
}

Expand Down Expand Up @@ -105,63 +110,35 @@ public Object completions(@Valid @RequestBody OpenAiCompletionRequest request)
@Operation(summary = "Creates a chat completion", description = "Creates a chat completion for the provided prompt and parameters.")
@ApiResponse(responseCode = "200", content = @Content(schema = @Schema(implementation = OpenAiChatCompletionResponse.class)))
public Object chatCompletions(@Valid @RequestBody OpenAiChatCompletionRequest request) throws Exception {
FarmEntity farm = farmRepository
.findUniqueByFieldValue("raagIdentifier", request.getModel())
.orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Farm not found"));
if (Boolean.TRUE.equals(request.getStream())) {
UUID emitterID = sseService.createEmitter();
for (int i = 0; i < 3; i++) {
OpenAiChatCompletionResponse responseChunk = new OpenAiChatCompletionResponse();
responseChunk.setId(emitterID.toString());
responseChunk.setModel(request.getModel());
responseChunk.setCreated(System.currentTimeMillis() / 1000);
responseChunk.setObject("chat.completion.chunk");
List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
choices.add(OpenAiChatCompletionChoice
.builder()
.index(0)
.finish_reason(i == 2 ? "stop" : null)
.delta(OpenAiCompletionMessage
.builder()
.role("assistant")
.content("Chunk : " + i + "\r\n")
.build())
.build());
responseChunk.setChoices(choices);
sseService.sendEvent(emitterID, responseChunk);
}
sseService.sendEvent(emitterID, "[DONE]");
new Thread(() -> {
try {
Thread.sleep(500);
} catch (InterruptedException e) {
e.printStackTrace();
}
sseService.complete(emitterID);
}).start();
return sseService.retrieveEmitter(emitterID);
return raagService.chatResponse(farm, request);
} else {
OpenAiChatCompletionResponse response = new OpenAiChatCompletionResponse();
response.setId(UUID.randomUUID().toString());
response.setModel(request.getModel());
response.setCreated(System.currentTimeMillis() / 1000);
response.setObject("chat.completion");
response.setUsage(OpenAiCompletionUsage
.builder()
.completion_tokens(0)
.prompt_tokens(0)
.total_tokens(0)
.build());
List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
choices.add(OpenAiChatCompletionChoice
return OpenAiChatCompletionResponse
.builder()
.index(0)
.finish_reason("stop")
.message(OpenAiCompletionMessage
.id(UUID.randomUUID().toString())
.model(request.getModel())
.created(System.currentTimeMillis() / 1000)
.object("chat.completion")
.usage(OpenAiCompletionUsage
.builder()
.role("assistant")
.content("Hello, how can I help you today?")
.completion_tokens(0)
.prompt_tokens(0)
.total_tokens(0)
.build())
.build());
response.setChoices(choices);
return response;
.choices(List.of(OpenAiChatCompletionChoice
.builder()
.index(0)
.finish_reason("stop")
.message(OpenAiCompletionMessage
.builder()
.role("assistant")
.content("Hello, how can I help you today?")
.build())
.build()))
.build();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package ai.dragon.dto.llm;

import java.util.function.Function;

import ai.dragon.enumeration.ProviderType;
import ai.dragon.properties.embedding.LanguageModelSettings;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import lombok.Builder;
import lombok.Data;

@Data
@Builder
public class StreamingChatLanguageModelDefinition {
private Function<LanguageModelSettings, StreamingChatLanguageModel> modelWithSettings;
private ProviderType providerType;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;

import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
Expand All @@ -11,6 +12,7 @@
public class OpenAiChatCompletionRequest {
@NotBlank
@NotNull
@Schema(description = "Name of the Farm 'Raag Model' to be used.")
private String model;

@NotEmpty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import java.util.List;

import lombok.Builder;
import lombok.Data;

@Data
@Builder
public class OpenAiChatCompletionResponse {
private String id;
private String object;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package ai.dragon.dto.openai.completion;

import jakarta.validation.constraints.NotNull;

import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;

@Data
public class OpenAiCompletionRequest {
@NotBlank
@NotNull
@Schema(description = "Name of the Farm 'Raag Model' to be used.")
private String model;

@NotNull
Expand Down
21 changes: 19 additions & 2 deletions backend/src/main/java/ai/dragon/entity/FarmEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
import org.dizitart.no2.repository.annotations.Index;
import org.dizitart.no2.repository.annotations.Indices;

import ai.dragon.enumeration.LanguageModelType;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Pattern;
import lombok.Getter;
import lombok.Setter;

@Entity(value = "farm")
@Schema(name = "Farm", description = "Farm Entity")
@Indices({ @Index(fields = "name", type = IndexType.UNIQUE) })
@Indices({
@Index(fields = "name", type = IndexType.UNIQUE),
@Index(fields = "raagIdentifier", type = IndexType.UNIQUE)
})
@Getter
@Setter
public class FarmEntity implements AbstractEntity {
Expand All @@ -30,14 +35,26 @@ public class FarmEntity implements AbstractEntity {
@Schema(description = "Name of the Farm. Must be unique.")
private String name;

@Schema(description = "Identifier for the 'Raag Model' to be used for the RaaG API. Must be unique")
@Pattern(regexp = "^[a-z0-9\\-]+$", message = "Must be alphanumeric, hyphens allowed")
private String raagIdentifier;

@Schema(description = """
List of Silo UUIDs to be linked to the Farm.
A farm is a collection of Silos, each Silo is a collection of Documents.""")
A Farm is a collection of Silos, which each Silo is a collection of Documents.""")
private List<UUID> silos;

@Schema(description = "Language Model to be used for the RaaG API")
private LanguageModelType languageModel;

@Schema(description = "Settings to be linked to the Farm's Language Model in the form of `key = value` pairs.")
private List<String> languageModelSettings;

public FarmEntity() {
this.uuid = UUID.randomUUID();
this.name = String.format("Farm %s", this.uuid.toString());
this.silos = new ArrayList<UUID>();
this.raagIdentifier = UUID.randomUUID().toString();
this.languageModel = LanguageModelType.OpenAiModel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ public static EmbeddingModelType fromString(String text) {
return null;
}

@Override
public String toString() {
return value;
}

public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundException {
switch (this) {
case BgeSmallEnV15QuantizedEmbeddingModel:
Expand Down Expand Up @@ -93,9 +98,4 @@ public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundExcepti
throw new ClassNotFoundException("Model not found");
}
}

@Override
public String toString() {
return value;
}
}
51 changes: 51 additions & 0 deletions backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package ai.dragon.enumeration;

import java.time.Duration;

import ai.dragon.dto.llm.StreamingChatLanguageModelDefinition;
import ai.dragon.service.SseService;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;

public enum LanguageModelType {
OpenAiModel("OpenAiModel");

private String value;

LanguageModelType(String value) {
this.value = value;
}

public static LanguageModelType fromString(String text) {
for (LanguageModelType b : LanguageModelType.values()) {
if (b.value.equalsIgnoreCase(text)) {
return b;
}
}
return null;
}

@Override
public String toString() {
return value;
}

public StreamingChatLanguageModelDefinition getStreamingChatLanguageModel() throws ClassNotFoundException {
switch (this) {
case OpenAiModel:
return StreamingChatLanguageModelDefinition
.builder()
.modelWithSettings(parameters -> {
return OpenAiStreamingChatModel
.builder()
.apiKey(parameters.getApiKey())
.modelName(parameters.getModelName())
.timeout(Duration.ofSeconds(SseService.DEFAULT_TIMEOUT))
.build();
})
.providerType(ProviderType.OpenAI)
.build();
default:
throw new ClassNotFoundException("Model not found");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public void checkIngestorLoaderSettings() throws Exception {
String[] paths = loaderSettings.getPath().trim().split(",");
pathsToIngest.clear();
for (String path : paths) {
File pathFile = new File(path);
File pathFile = new File(path.trim());
if (!pathFile.exists() || !pathFile.isDirectory()) {
logger.warn("Skipping directory because not found : {}", pathFile);
continue;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package ai.dragon.properties.embedding;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;

import lombok.Data;

@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class LanguageModelSettings {
private String apiKey;
private String modelName;
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ public Cursor<T> findByFieldValue(String fieldName, Object fieldValue) {
return this.findWithFilter(FluentFilter.where(fieldName).eq(fieldValue));
}

public Optional<T> findUniqueByFieldValue(String fieldName, Object fieldValue) {
Cursor<T> cursor = this.findByFieldValue(fieldName, fieldValue);
if (cursor.size() > 1) {
throw new ResponseStatusException(HttpStatus.CONFLICT, "Multiple entities found");
}
return cursor.size() == 1 ? Optional.of(cursor.firstOrNull()) : Optional.empty();
}

public void delete(String uuid) {
delete(UUID.fromString(uuid));
}
Expand Down
Loading