Skip to content

Commit

Permalink
Init RaagService to call OpenAI LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
amengus87 committed Jun 11, 2024
1 parent 40a7f67 commit 66f34bf
Show file tree
Hide file tree
Showing 15 changed files with 302 additions and 100 deletions.
2 changes: 1 addition & 1 deletion backend/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
<property name="caseIndent" value="4"/>
<property name="throwsIndent" value="4"/>
<property name="arrayInitIndent" value="4"/>
<property name="lineWrappingIndentation" value="4"/>
<property name="lineWrappingIndentation" value="8"/>
</module>
<module name="RegexpSinglelineJava">
<property name="format" value="sk\-proj\-[a-zA-Z0-9+]"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,11 @@
import ai.dragon.dto.openai.completion.OpenAiCompletionRequest;
import ai.dragon.dto.openai.completion.OpenAiCompletionResponse;
import ai.dragon.dto.openai.completion.OpenAiCompletionUsage;
import ai.dragon.dto.openai.model.OpenAiModel;
import ai.dragon.dto.openai.model.OpenAiModelsReponse;
import ai.dragon.entity.SiloEntity;
import ai.dragon.repository.SiloRepository;
import ai.dragon.service.EmbeddingModelService;
import ai.dragon.service.EmbeddingStoreService;
import ai.dragon.entity.FarmEntity;
import ai.dragon.repository.FarmRepository;
import ai.dragon.service.RaagService;
import ai.dragon.service.SseService;
import ai.dragon.util.ai.AiAssistant;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
Expand All @@ -52,24 +41,16 @@ public class OpenAiCompatibleV1ApiController {
private SseService sseService;

@Autowired
private EmbeddingStoreService embeddingStoreService;
private FarmRepository farmRepository;

@Autowired
private EmbeddingModelService embeddingModelService;

@Autowired
private SiloRepository siloRepository;
private RaagService raagService;

@GetMapping("/models")
@Operation(summary = "List models", description = "Lists available models.")
public OpenAiModelsReponse models() {
return OpenAiModelsReponse.builder()
.data(List.of(OpenAiModel
.builder()
.created(System.currentTimeMillis() / 1000)
.id("dragon-ppx")
.owned_by("dRAGon")
.build()))
.data(raagService.listAvailableModels())
.build();
}

Expand Down Expand Up @@ -129,67 +110,11 @@ public Object completions(@Valid @RequestBody OpenAiCompletionRequest request)
@Operation(summary = "Creates a chat completion", description = "Creates a chat completion for the provided prompt and parameters.")
@ApiResponse(responseCode = "200", content = @Content(schema = @Schema(implementation = OpenAiChatCompletionResponse.class)))
public Object chatCompletions(@Valid @RequestBody OpenAiChatCompletionRequest request) throws Exception {
FarmEntity farm = farmRepository
.findUniqueByFieldValue("raagIdentifier", request.getModel())
.orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Farm not found"));
if (Boolean.TRUE.equals(request.getStream())) {
UUID siloUuid = UUID.fromString(request.getModel());
SiloEntity silo = siloRepository.getByUuid(siloUuid)
.orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Entity not found"));
UUID emitterID = sseService.createEmitter();

EmbeddingStore<TextSegment> embeddingStore = embeddingStoreService.retrieveEmbeddingStore(siloUuid);
EmbeddingModel embeddingModel = embeddingModelService.modelForEntity(silo);
ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.dynamicMaxResults(query -> {
return 10;
})
.dynamicMinScore(query -> {
return 0.8;
})
.dynamicFilter(query -> {
return null;
})
.build();
AiAssistant assistant = AiServices.builder(AiAssistant.class)
.streamingChatLanguageModel(OpenAiStreamingChatModel
.withApiKey("TODO"))
// .chatLanguageModel(
// OpenAiChatModel.withApiKey("TODO"))
// .chatMemory(MessageWindowChatMemory.withMaxMessages(10)) // it should
// remember 10 latest messages
// TODO : .retrievalAugmentor(retrievalAugmentor)
.contentRetriever(contentRetriever)
.build();
TokenStream stream = assistant.chat((String) request.getMessages().get(0).getContent());

stream.onNext(nextChunk -> {
OpenAiChatCompletionResponse responseChunk = new OpenAiChatCompletionResponse();
responseChunk.setId(emitterID.toString());
responseChunk.setModel(request.getModel());
responseChunk.setCreated(System.currentTimeMillis() / 1000);
responseChunk.setObject("chat.completion.chunk");
List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
choices.add(OpenAiChatCompletionChoice
.builder()
.index(0)
// .finish_reason(i == 2 ? "stop" : null)
.delta(OpenAiCompletionMessage
.builder()
.role("assistant")
.content(nextChunk)
.build())
.build());
responseChunk.setChoices(choices);
sseService.sendEvent(emitterID, responseChunk);
})
.onComplete(response -> {
// response.finishReason(). // "stop"
sseService.sendEvent(emitterID, "[DONE]");
sseService.complete(emitterID);
})
.onError(Throwable::printStackTrace)
.start();
return sseService.retrieveEmitter(emitterID);
return raagService.chatResponse(farm, request);
} else {
OpenAiChatCompletionResponse response = new OpenAiChatCompletionResponse();
response.setId(UUID.randomUUID().toString());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package ai.dragon.dto.llm;

import java.util.function.Function;

import ai.dragon.enumeration.ProviderType;
import ai.dragon.properties.embedding.LanguageModelSettings;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import lombok.Builder;
import lombok.Data;

@Data
@Builder
public class StreamingChatLanguageModelDefinition {
private Function<LanguageModelSettings, StreamingChatLanguageModel> modelWithSettings;
private ProviderType providerType;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;

import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
Expand All @@ -11,6 +12,7 @@
public class OpenAiChatCompletionRequest {
@NotBlank
@NotNull
@Schema(description = "Name of the Farm 'Raag Model' to be used.")
private String model;

@NotEmpty
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package ai.dragon.dto.openai.completion;

import jakarta.validation.constraints.NotNull;

import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;

@Data
public class OpenAiCompletionRequest {
@NotBlank
@NotNull
@Schema(description = "Name of the Farm 'Raag Model' to be used.")
private String model;

@NotNull
Expand Down
21 changes: 19 additions & 2 deletions backend/src/main/java/ai/dragon/entity/FarmEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
import org.dizitart.no2.repository.annotations.Index;
import org.dizitart.no2.repository.annotations.Indices;

import ai.dragon.enumeration.LanguageModelType;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Pattern;
import lombok.Getter;
import lombok.Setter;

@Entity(value = "farm")
@Schema(name = "Farm", description = "Farm Entity")
@Indices({ @Index(fields = "name", type = IndexType.UNIQUE) })
@Indices({
@Index(fields = "name", type = IndexType.UNIQUE),
@Index(fields = "raagIdentifier", type = IndexType.UNIQUE)
})
@Getter
@Setter
public class FarmEntity implements AbstractEntity {
Expand All @@ -30,14 +35,26 @@ public class FarmEntity implements AbstractEntity {
@Schema(description = "Name of the Farm. Must be unique.")
private String name;

@Schema(description = "Identifier for the 'Raag Model' to be used for the RaaG API. Must be unique")
@Pattern(regexp = "^[a-z0-9\\-]+$", message = "Must be alphanumeric, hyphens allowed")
private String raagIdentifier;

@Schema(description = """
List of Silo UUIDs to be linked to the Farm.
A farm is a collection of Silos, each Silo is a collection of Documents.""")
A Farm is a collection of Silos, which each Silo is a collection of Documents.""")
private List<UUID> silos;

@Schema(description = "Language Model to be used for the RaaG API")
private LanguageModelType languageModel;

@Schema(description = "Settings to be linked to the Farm's Language Model in the form of `key = value` pairs.")
private List<String> languageModelSettings;

public FarmEntity() {
this.uuid = UUID.randomUUID();
this.name = String.format("Farm %s", this.uuid.toString());
this.silos = new ArrayList<UUID>();
this.raagIdentifier = UUID.randomUUID().toString();
this.languageModel = LanguageModelType.OpenAiModel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ public static EmbeddingModelType fromString(String text) {
return null;
}

@Override
public String toString() {
return value;
}

public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundException {
switch (this) {
case BgeSmallEnV15QuantizedEmbeddingModel:
Expand Down Expand Up @@ -93,9 +98,4 @@ public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundExcepti
throw new ClassNotFoundException("Model not found");
}
}

@Override
public String toString() {
return value;
}
}
47 changes: 47 additions & 0 deletions backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package ai.dragon.enumeration;

import ai.dragon.dto.llm.StreamingChatLanguageModelDefinition;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;

public enum LanguageModelType {
OpenAiModel("OpenAiModel");

private String value;

LanguageModelType(String value) {
this.value = value;
}

public static LanguageModelType fromString(String text) {
for (LanguageModelType b : LanguageModelType.values()) {
if (b.value.equalsIgnoreCase(text)) {
return b;
}
}
return null;
}

@Override
public String toString() {
return value;
}

public StreamingChatLanguageModelDefinition getStreamingChatLanguageModel() throws ClassNotFoundException {
switch (this) {
case OpenAiModel:
return StreamingChatLanguageModelDefinition
.builder()
.modelWithSettings(parameters -> {
return OpenAiStreamingChatModel
.builder()
.apiKey(parameters.getApiKey())
.modelName(parameters.getModelName())
.build();
})
.providerType(ProviderType.OpenAI)
.build();
default:
throw new ClassNotFoundException("Model not found");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package ai.dragon.properties.embedding;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;

import lombok.Data;

@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class LanguageModelSettings {
private String apiKey;
private String modelName;
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ public Cursor<T> findByFieldValue(String fieldName, Object fieldValue) {
return this.findWithFilter(FluentFilter.where(fieldName).eq(fieldValue));
}

public Optional<T> findUniqueByFieldValue(String fieldName, Object fieldValue) {
Cursor<T> cursor = this.findByFieldValue(fieldName, fieldValue);
if (cursor.size() > 1) {
throw new ResponseStatusException(HttpStatus.CONFLICT, "Multiple entities found");
}
return cursor.size() == 1 ? Optional.of(cursor.firstOrNull()) : Optional.empty();
}

public void delete(String uuid) {
delete(UUID.fromString(uuid));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,32 @@
package ai.dragon.service;

import java.util.UUID;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.web.server.ResponseStatusException;

import ai.dragon.entity.SiloEntity;
import ai.dragon.properties.embedding.EmbeddingSettings;
import ai.dragon.repository.SiloRepository;
import dev.langchain4j.model.embedding.EmbeddingModel;

@Service
public class EmbeddingModelService {
@Autowired
private KVSettingService kvSettingService;

public EmbeddingModel modelForEntity(SiloEntity siloEntity) throws Exception {
@Autowired
private SiloRepository siloRepository;

public EmbeddingModel modelForSilo(UUID siloUuid) throws Exception {
SiloEntity siloEntity = siloRepository.getByUuid(siloUuid)
.orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Silo not found"));
return modelForSilo(siloEntity);
}

public EmbeddingModel modelForSilo(SiloEntity siloEntity) throws Exception {
EmbeddingSettings embeddingSettings = kvSettingService.kvSettingsToObject(
siloEntity.getEmbeddingSettings(), EmbeddingSettings.class);
return siloEntity.getEmbeddingModel().getModelDefinition().getEmbeddingModelWithSettings()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void clearEmbeddingStore(UUID siloUuid) throws Exception {
public EmbeddingSearchResult<TextSegment> query(UUID siloUuid, String query, Integer maxResults) throws Exception {
SiloEntity siloEntity = siloRepository.getByUuid(siloUuid).orElseThrow();
EmbeddingStore<TextSegment> embeddingStore = retrieveEmbeddingStore(siloUuid);
EmbeddingModel embeddingModel = embeddingModelService.modelForEntity(siloEntity);
EmbeddingModel embeddingModel = embeddingModelService.modelForSilo(siloEntity);
Embedding queryEmbedding = embeddingModel.embed(query).content();
// Filter onlyForUser1 = metadataKey("userId").isEqualTo("1");
EmbeddingSearchRequest embeddingSearchRequest1 = EmbeddingSearchRequest.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private void ingestDocumentsToSilo(List<Document> documents, SiloEntity siloEnti
throws Exception {
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreService
.retrieveEmbeddingStore(siloEntity.getUuid());
EmbeddingModel embeddingModel = embeddingModelService.modelForEntity(siloEntity);
EmbeddingModel embeddingModel = embeddingModelService.modelForSilo(siloEntity);
EmbeddingStoreIngestor ingestor = buildIngestor(embeddingStore, embeddingModel, siloEntity);
logCallback.accept(SiloIngestLoaderLogMessage.builder()
.message(String.format(
Expand Down
Loading

0 comments on commit 66f34bf

Please sign in to comment.