Skip to content

Commit

Permalink
RaaG Init (#49)
Browse files Browse the repository at this point in the history
* OpenAiCompatibleV1ApiControllerTest : listModels() (#44) (#45)

Co-authored-by: Arnaud MENGUS <[email protected]>

* Raag init OpenAI (#46)

* Adding CheckStyle module for banning OpenAI API Keys

* RaaG : Begin with langchain4j assistant

* Init RaagService to call OpenAI LLM

* Fix OpenAiCompatibleV1ApiControllerTest

* Fix : Adding finish_reason on last chunk (#47)

* Adding CheckStyle module for banning OpenAI API Keys

* RaaG : Begin with langchain4j assistant

* Init RaagService to call OpenAI LLM

* Fix OpenAiCompatibleV1ApiControllerTest

* Adding finish_reason in order to fix : [OpenAIClient.chatCompletion][stream] Missing finish_reason

* Increase SSE timeout from 10 to 90 secs

* Increase OpenAI client timeout from 60 to 90 secs

* Begin Chat Memory (#48)

* Adding CheckStyle module for banning OpenAI API Keys

* RaaG : Begin with langchain4j assistant

* Init RaagService to call OpenAI LLM

* Fix OpenAiCompatibleV1ApiControllerTest

* Adding finish_reason in order to fix : [OpenAIClient.chatCompletion][stream] Missing finish_reason

* Increase SSE timeout from 10 to 90 secs

* Increase OpenAI client timeout from 60 to 90 secs

* Begin Chat Memory

* Ability to use the Chat Memory

* Fix checkstyle
  • Loading branch information
isontheline authored Jun 11, 2024
1 parent e159aab commit e028201
Show file tree
Hide file tree
Showing 21 changed files with 477 additions and 82 deletions.
6 changes: 5 additions & 1 deletion backend/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
<property name="caseIndent" value="4"/>
<property name="throwsIndent" value="4"/>
<property name="arrayInitIndent" value="4"/>
<property name="lineWrappingIndentation" value="4"/>
<property name="lineWrappingIndentation" value="8"/>
</module>
<module name="RegexpSinglelineJava">
<property name="format" value="sk\-proj\-[a-zA-Z0-9+]"/>
<property name="ignoreComments" value="false"/>
</module>
</module>
<module name="NewlineAtEndOfFile"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import java.util.UUID;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ResponseStatusException;

import ai.dragon.dto.openai.completion.OpenAiChatCompletionChoice;
import ai.dragon.dto.openai.completion.OpenAiChatCompletionRequest;
Expand All @@ -19,8 +21,10 @@
import ai.dragon.dto.openai.completion.OpenAiCompletionRequest;
import ai.dragon.dto.openai.completion.OpenAiCompletionResponse;
import ai.dragon.dto.openai.completion.OpenAiCompletionUsage;
import ai.dragon.dto.openai.model.OpenAiModel;
import ai.dragon.dto.openai.model.OpenAiModelsReponse;
import ai.dragon.entity.FarmEntity;
import ai.dragon.repository.FarmRepository;
import ai.dragon.service.RaagService;
import ai.dragon.service.SseService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.Content;
Expand All @@ -36,16 +40,17 @@ public class OpenAiCompatibleV1ApiController {
@Autowired
private SseService sseService;

@Autowired
private FarmRepository farmRepository;

@Autowired
private RaagService raagService;

@GetMapping("/models")
@Operation(summary = "List models", description = "Lists available models.")
public OpenAiModelsReponse models() {
return OpenAiModelsReponse.builder()
.data(List.of(OpenAiModel
.builder()
.created(System.currentTimeMillis() / 1000)
.id("dragon-ppx")
.owned_by("dRAGon")
.build()))
.data(raagService.listAvailableModels())
.build();
}

Expand Down Expand Up @@ -105,63 +110,35 @@ public Object completions(@Valid @RequestBody OpenAiCompletionRequest request)
@Operation(summary = "Creates a chat completion", description = "Creates a chat completion for the provided prompt and parameters.")
@ApiResponse(responseCode = "200", content = @Content(schema = @Schema(implementation = OpenAiChatCompletionResponse.class)))
public Object chatCompletions(@Valid @RequestBody OpenAiChatCompletionRequest request) throws Exception {
FarmEntity farm = farmRepository
.findUniqueByFieldValue("raagIdentifier", request.getModel())
.orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Farm not found"));
if (Boolean.TRUE.equals(request.getStream())) {
UUID emitterID = sseService.createEmitter();
for (int i = 0; i < 3; i++) {
OpenAiChatCompletionResponse responseChunk = new OpenAiChatCompletionResponse();
responseChunk.setId(emitterID.toString());
responseChunk.setModel(request.getModel());
responseChunk.setCreated(System.currentTimeMillis() / 1000);
responseChunk.setObject("chat.completion.chunk");
List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
choices.add(OpenAiChatCompletionChoice
.builder()
.index(0)
.finish_reason(i == 2 ? "stop" : null)
.delta(OpenAiCompletionMessage
.builder()
.role("assistant")
.content("Chunk : " + i + "\r\n")
.build())
.build());
responseChunk.setChoices(choices);
sseService.sendEvent(emitterID, responseChunk);
}
sseService.sendEvent(emitterID, "[DONE]");
new Thread(() -> {
try {
Thread.sleep(500);
} catch (InterruptedException e) {
e.printStackTrace();
}
sseService.complete(emitterID);
}).start();
return sseService.retrieveEmitter(emitterID);
return raagService.chatResponse(farm, request);
} else {
OpenAiChatCompletionResponse response = new OpenAiChatCompletionResponse();
response.setId(UUID.randomUUID().toString());
response.setModel(request.getModel());
response.setCreated(System.currentTimeMillis() / 1000);
response.setObject("chat.completion");
response.setUsage(OpenAiCompletionUsage
.builder()
.completion_tokens(0)
.prompt_tokens(0)
.total_tokens(0)
.build());
List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
choices.add(OpenAiChatCompletionChoice
return OpenAiChatCompletionResponse
.builder()
.index(0)
.finish_reason("stop")
.message(OpenAiCompletionMessage
.id(UUID.randomUUID().toString())
.model(request.getModel())
.created(System.currentTimeMillis() / 1000)
.object("chat.completion")
.usage(OpenAiCompletionUsage
.builder()
.role("assistant")
.content("Hello, how can I help you today?")
.completion_tokens(0)
.prompt_tokens(0)
.total_tokens(0)
.build())
.build());
response.setChoices(choices);
return response;
.choices(List.of(OpenAiChatCompletionChoice
.builder()
.index(0)
.finish_reason("stop")
.message(OpenAiCompletionMessage
.builder()
.role("assistant")
.content("Hello, how can I help you today?")
.build())
.build()))
.build();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package ai.dragon.dto.llm;

import java.util.function.Function;

import ai.dragon.enumeration.ProviderType;
import ai.dragon.properties.embedding.LanguageModelSettings;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import lombok.Builder;
import lombok.Data;

@Data
@Builder
public class StreamingChatLanguageModelDefinition {
private Function<LanguageModelSettings, StreamingChatLanguageModel> modelWithSettings;
private ProviderType providerType;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;

import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
Expand All @@ -11,6 +12,7 @@
public class OpenAiChatCompletionRequest {
@NotBlank
@NotNull
@Schema(description = "Name of the Farm 'Raag Model' to be used.")
private String model;

@NotEmpty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import java.util.List;

import lombok.Builder;
import lombok.Data;

@Data
@Builder
public class OpenAiChatCompletionResponse {
private String id;
private String object;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package ai.dragon.dto.openai.completion;

import jakarta.validation.constraints.NotNull;

import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;

@Data
public class OpenAiCompletionRequest {
@NotBlank
@NotNull
@Schema(description = "Name of the Farm 'Raag Model' to be used.")
private String model;

@NotNull
Expand Down
21 changes: 19 additions & 2 deletions backend/src/main/java/ai/dragon/entity/FarmEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
import org.dizitart.no2.repository.annotations.Index;
import org.dizitart.no2.repository.annotations.Indices;

import ai.dragon.enumeration.LanguageModelType;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Pattern;
import lombok.Getter;
import lombok.Setter;

@Entity(value = "farm")
@Schema(name = "Farm", description = "Farm Entity")
@Indices({ @Index(fields = "name", type = IndexType.UNIQUE) })
@Indices({
@Index(fields = "name", type = IndexType.UNIQUE),
@Index(fields = "raagIdentifier", type = IndexType.UNIQUE)
})
@Getter
@Setter
public class FarmEntity implements AbstractEntity {
Expand All @@ -30,14 +35,26 @@ public class FarmEntity implements AbstractEntity {
@Schema(description = "Name of the Farm. Must be unique.")
private String name;

@Schema(description = "Identifier for the 'Raag Model' to be used for the RaaG API. Must be unique")
@Pattern(regexp = "^[a-z0-9\\-]+$", message = "Must be alphanumeric, hyphens allowed")
private String raagIdentifier;

@Schema(description = """
List of Silo UUIDs to be linked to the Farm.
A farm is a collection of Silos, each Silo is a collection of Documents.""")
A Farm is a collection of Silos, which each Silo is a collection of Documents.""")
private List<UUID> silos;

@Schema(description = "Language Model to be used for the RaaG API")
private LanguageModelType languageModel;

@Schema(description = "Settings to be linked to the Farm's Language Model in the form of `key = value` pairs.")
private List<String> languageModelSettings;

public FarmEntity() {
this.uuid = UUID.randomUUID();
this.name = String.format("Farm %s", this.uuid.toString());
this.silos = new ArrayList<UUID>();
this.raagIdentifier = UUID.randomUUID().toString();
this.languageModel = LanguageModelType.OpenAiModel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ public static EmbeddingModelType fromString(String text) {
return null;
}

@Override
public String toString() {
return value;
}

public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundException {
switch (this) {
case BgeSmallEnV15QuantizedEmbeddingModel:
Expand Down Expand Up @@ -93,9 +98,4 @@ public EmbeddingModelDefinition getModelDefinition() throws ClassNotFoundExcepti
throw new ClassNotFoundException("Model not found");
}
}

@Override
public String toString() {
return value;
}
}
51 changes: 51 additions & 0 deletions backend/src/main/java/ai/dragon/enumeration/LanguageModelType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package ai.dragon.enumeration;

import java.time.Duration;

import ai.dragon.dto.llm.StreamingChatLanguageModelDefinition;
import ai.dragon.service.SseService;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;

public enum LanguageModelType {
OpenAiModel("OpenAiModel");

private String value;

LanguageModelType(String value) {
this.value = value;
}

public static LanguageModelType fromString(String text) {
for (LanguageModelType b : LanguageModelType.values()) {
if (b.value.equalsIgnoreCase(text)) {
return b;
}
}
return null;
}

@Override
public String toString() {
return value;
}

public StreamingChatLanguageModelDefinition getStreamingChatLanguageModel() throws ClassNotFoundException {
switch (this) {
case OpenAiModel:
return StreamingChatLanguageModelDefinition
.builder()
.modelWithSettings(parameters -> {
return OpenAiStreamingChatModel
.builder()
.apiKey(parameters.getApiKey())
.modelName(parameters.getModelName())
.timeout(Duration.ofSeconds(SseService.DEFAULT_TIMEOUT))
.build();
})
.providerType(ProviderType.OpenAI)
.build();
default:
throw new ClassNotFoundException("Model not found");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public void checkIngestorLoaderSettings() throws Exception {
String[] paths = loaderSettings.getPath().trim().split(",");
pathsToIngest.clear();
for (String path : paths) {
File pathFile = new File(path);
File pathFile = new File(path.trim());
if (!pathFile.exists() || !pathFile.isDirectory()) {
logger.warn("Skipping directory because not found : {}", pathFile);
continue;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package ai.dragon.properties.embedding;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;

import lombok.Data;

@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class LanguageModelSettings {
private String apiKey;
private String modelName;
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ public Cursor<T> findByFieldValue(String fieldName, Object fieldValue) {
return this.findWithFilter(FluentFilter.where(fieldName).eq(fieldValue));
}

public Optional<T> findUniqueByFieldValue(String fieldName, Object fieldValue) {
Cursor<T> cursor = this.findByFieldValue(fieldName, fieldValue);
if (cursor.size() > 1) {
throw new ResponseStatusException(HttpStatus.CONFLICT, "Multiple entities found");
}
return cursor.size() == 1 ? Optional.of(cursor.firstOrNull()) : Optional.empty();
}

public void delete(String uuid) {
delete(UUID.fromString(uuid));
}
Expand Down
Loading

0 comments on commit e028201

Please sign in to comment.