Skip to content

Commit

Permalink
Merge pull request #183 from dRAGon-Okinawa/staging
Browse files Browse the repository at this point in the history
Staging
  • Loading branch information
amengus87 authored Jul 26, 2024
2 parents fcf6ceb + dfa4b07 commit 92176cb
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:
- workflow_dispatch
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: false
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
build:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
![Build dRAGon Project](https://github.com/dragon-okinawa/dragon/actions/workflows/build.yml/badge.svg?branch=main)
[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon)
[![Vulnerabilities](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=vulnerabilities)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon)
[![Bugs](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=bugs)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon)
Expand Down
29 changes: 27 additions & 2 deletions backend/build.gradle
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import org.gradle.api.tasks.testing.logging.TestExceptionFormat
import org.gradle.api.tasks.testing.logging.TestLogEvent

plugins {
id 'java'
id 'java-library'
id 'jacoco'
id 'checkstyle'
id 'application'
id 'org.springframework.boot' version '3.3.0'
id 'org.springframework.boot' version '3.3.2'
id 'io.spring.dependency-management' version '1.1.6'
id 'org.springdoc.openapi-gradle-plugin' version '1.8.0'
}
Expand All @@ -24,7 +27,7 @@ dependencies {
annotationProcessor "org.springframework.boot:spring-boot-configuration-processor"
implementation 'org.springframework.boot:spring-boot-starter-security'
implementation 'org.springframework.boot:spring-boot-starter-web'
implementation 'org.springframework.boot:spring-boot-starter-thymeleaf:3.2.5'
implementation 'org.springframework.boot:spring-boot-starter-thymeleaf:3.3.2'
implementation 'org.springframework.boot:spring-boot-starter-validation'
implementation 'org.hibernate.validator:hibernate-validator'
developmentOnly 'org.springframework.boot:spring-boot-devtools'
Expand Down Expand Up @@ -144,3 +147,25 @@ tasks.named('compileJava') {
tasks.named('processResources') {
dependsOn('copyWebApp')
}

// Based on : https://stackoverflow.com/a/36130467/8102448
tasks.withType(Test) {
testLogging {
events TestLogEvent.FAILED,
TestLogEvent.PASSED,
TestLogEvent.SKIPPED
exceptionFormat TestExceptionFormat.FULL
showExceptions true
showCauses true
showStackTraces true

afterSuite { desc, result ->
if (!desc.parent) {
def output = "Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} passed, ${result.failedTestCount} failed, ${result.skippedTestCount} skipped)"
def startItem = '| ', endItem = ' |'
def repeatLength = startItem.length() + output.length() + endItem.length()
println('\n' + ('-' * repeatLength) + '\n' + startItem + output + endItem + '\n' + ('-' * repeatLength))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.io.InterruptedIOException;
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIf;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -19,10 +25,20 @@
import org.springframework.test.context.ActiveProfiles;

import ai.dragon.entity.FarmEntity;
import ai.dragon.entity.SiloEntity;
import ai.dragon.enumeration.EmbeddingModelType;
import ai.dragon.enumeration.IngestorLoaderType;
import ai.dragon.enumeration.LanguageModelType;
import ai.dragon.enumeration.VectorStoreType;
import ai.dragon.junit.AbstractTest;
import ai.dragon.junit.extension.retry.RetryOnExceptions;
import ai.dragon.repository.FarmRepository;
import ai.dragon.test.AbstractTest;
import ai.dragon.repository.SiloRepository;
import ai.dragon.service.IngestorService;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.completion.CompletionRequest;
import dev.ai4j.openai4j.completion.CompletionResponse;
import dev.langchain4j.model.mistralai.internal.api.MistralAiModelResponse;
Expand All @@ -34,67 +50,247 @@ public class OpenAiCompatibleV1ApiControllerTest extends AbstractTest {
@LocalServerPort
private int serverPort;

@Autowired
private FarmRepository farmRepository;

@BeforeAll
static void beforeAll(@Autowired FarmRepository farmRepository) {
farmRepository.deleteAll();
static void beforeAll(@Autowired FarmRepository farmRepository,
@Autowired SiloRepository siloRepository,
@Autowired IngestorService ingestorService) throws Exception {
cleanUp(farmRepository, siloRepository);

// OpenAI settings for RaaG
String apiKeySetting = String.format("apiKey=%s", System.getenv("OPENAI_API_KEY"));
String modelNameSetting = "modelName=gpt-4o";

// Farm with no silo
FarmEntity farmWithoutSilo = new FarmEntity();
farmWithoutSilo.setRaagIdentifier("no-silo-raag");
farmWithoutSilo.setLanguageModel(LanguageModelType.OpenAiModel);
farmWithoutSilo.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting));
farmRepository.save(farmWithoutSilo);

// Resources for the silo
String ragResourcesPath = "src/test/resources/rag_documents/sunspots";
File ragResources = new File(ragResourcesPath);
String ragResourcesAbsolutePath = ragResources.getAbsolutePath();
assertNotNull(ragResourcesAbsolutePath);

// Silo about "Sunspots"
SiloEntity sunspotsSilo = new SiloEntity();
sunspotsSilo.setUuid(UUID.randomUUID());
sunspotsSilo.setName("Sunspots Silo");
sunspotsSilo.setEmbeddingModel(EmbeddingModelType.BgeSmallEnV15QuantizedEmbeddingModel);
sunspotsSilo.setEmbeddingSettings(List.of(
"chunkSize=1000",
"chunkOverlap=100"));
sunspotsSilo.setVectorStore(VectorStoreType.InMemoryEmbeddingStore);
sunspotsSilo.setIngestorLoader(IngestorLoaderType.FileSystem);
sunspotsSilo.setIngestorSettings(List.of(
String.format("paths[]=%s", ragResourcesAbsolutePath),
"recursive=false",
"pathMatcher=glob:**.{pdf,doc,docx,ppt,pptx}"));
siloRepository.save(sunspotsSilo);

// Launching ingestion of documents inside the Silo
ingestorService.runSiloIngestion(sunspotsSilo, ingestProgress -> {
System.out.println("Ingest progress: " + ingestProgress);
}, ingestLogMessage -> {
System.out.println(ingestLogMessage.getMessage());
});

// Farm with the Sunspots Silo
FarmEntity farmWithSunspotsSilo = new FarmEntity();
farmWithSunspotsSilo.setRaagIdentifier("sunspots-raag");
farmWithSunspotsSilo.setLanguageModel(LanguageModelType.OpenAiModel);
farmWithSunspotsSilo.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting));
farmWithSunspotsSilo.setSilos(List.of(sunspotsSilo.getUuid()));
farmRepository.save(farmWithSunspotsSilo);

// Farm with the Sunspots Silo but with Query Rewriting
FarmEntity farmWithSunspotsSiloAndQueryRewriting = new FarmEntity();
farmWithSunspotsSiloAndQueryRewriting.setRaagIdentifier("sunspots-rewriting-raag");
farmWithSunspotsSiloAndQueryRewriting.setLanguageModel(LanguageModelType.OpenAiModel);
farmWithSunspotsSiloAndQueryRewriting.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting));
farmWithSunspotsSiloAndQueryRewriting.setSilos(List.of(sunspotsSilo.getUuid()));
farmWithSunspotsSiloAndQueryRewriting.setRetrievalAugmentorSettings(List.of("rewriteQuery=true"));
farmRepository.save(farmWithSunspotsSiloAndQueryRewriting);
}

@AfterAll
static void afterAll(@Autowired FarmRepository farmRepository) {
farmRepository.deleteAll();
static void afterAll(@Autowired FarmRepository farmRepository, @Autowired SiloRepository siloRepository) {
cleanUp(farmRepository, siloRepository);
}

@BeforeEach
void beforeEach(@Autowired FarmRepository farmRepository) {
static void cleanUp(FarmRepository farmRepository, SiloRepository siloRepository) {
farmRepository.deleteAll();
siloRepository.deleteAll();
}

@Test
void listModels() throws Exception {
FarmEntity farm = new FarmEntity();
farm.setRaagIdentifier("awesome-raag");
farmRepository.save(farm);
@SuppressWarnings("rawtypes")
private OpenAiClient.Builder createOpenAiClientBuilder() {
return OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
.baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort))
.callTimeout(Duration.ofSeconds(10))
.readTimeout(Duration.ofSeconds(10))
.writeTimeout(Duration.ofSeconds(10))
.connectTimeout(Duration.ofSeconds(10));
}

MistralAiClient client = MistralAiClient.builder()
@SuppressWarnings("rawtypes")
private MistralAiClient.Builder createMistralAiClientBuilder() {
return MistralAiClient.builder()
.apiKey("TODO_PUT_KEY_HERE")
.baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort))
.timeout(Duration.ofSeconds(10))
.logRequests(false)
.logResponses(false)
.build();
.logResponses(false);
}

@Test
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void listModels() throws Exception {
MistralAiClient client = createMistralAiClientBuilder().build();
MistralAiModelResponse modelsResponse = client.listModels();
assertNotNull(modelsResponse);
assertEquals(1, modelsResponse.getData().size());
assertEquals(farm.getRaagIdentifier(), modelsResponse.getData().get(0).getId());
assertNotEquals(0, modelsResponse.getData().size());
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
void testFarmNoSiloOpenAI() {
String apiKeySetting = String.format("apiKey=%s", System.getenv("OPENAI_API_KEY"));
String modelNameSetting = "modelName=gpt-4o";

FarmEntity farm = new FarmEntity();
farm.setRaagIdentifier("dragon-raag");
farm.setLanguageModel(LanguageModelType.OpenAiModel);
farm.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting));
farmRepository.save(farm);

OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
.baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort))
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testModelDoesntExistOpenAI() {
OpenAiClient client = createOpenAiClientBuilder().build();
CompletionRequest request = CompletionRequest.builder()
.model("should-not-exist")
.prompt("Just say : 'dRAGon'")
.build();
OpenAiHttpException exception = assertThrows(OpenAiHttpException.class,
() -> client.completion(request).execute());
assertTrue(exception.code() == 404);
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmNoSiloOpenAI() {
OpenAiClient client = createOpenAiClientBuilder().build();
CompletionRequest request = CompletionRequest.builder()
.model("dragon-raag")
.model("no-silo-raag")
.prompt("Just say 'HELLO' in lowercased letters.")
.temperature(0.0)
.build();
CompletionResponse response = client.completion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
assertEquals("hello", response.choices().get(0).text());
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmCompletionOpenAI() {
OpenAiClient client = createOpenAiClientBuilder().build();
CompletionRequest request = CompletionRequest.builder()
.model("sunspots-raag")
.prompt("Who is the author of document 'The Size of the Carrington Event Sunspot Group'? Just reply with the firstname and lastname.")
.stream(false)
.temperature(0.0)
.build();
CompletionResponse response = client.completion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
assertEquals("Peter Meadows", response.choices().get(0).text());
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmCompletionStreamOpenAI() {
OpenAiClient client = createOpenAiClientBuilder().build();
CompletionRequest request = CompletionRequest.builder()
.model("sunspots-raag")
.prompt("""
Who is the author of document 'Sunspots, unemployment, and recessions, or Can the solar activity cycle shape the business cycle?'?
Just reply with the firstname and lastname.
""")
.stream(true)
.temperature(0.0)
.build();
CompletionResponse response = client.completion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
assertEquals("Mikhail Gorbanev", response.choices().get(0).text());
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmChatRewriteQueryOpenAI() {
OpenAiClient client = createOpenAiClientBuilder().build();
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.addSystemMessage(
"You are a researcher in solar physics and you provide help to other researchers.")
.addUserMessage(
"Hello, I am looking for the author of the document 'The Size of the Carrington Event Sunspot Group'.")
.addUserMessage(
"""
Can you help me?
* Just use the context of this message and nothing else to reply.
* If the information is not provided, just say 'I do not know!'
* Just reply with the firstname and lastname.
""")
.temperature(0.0);

for (int i = 0; i <= 1; i++) {
for (int j = 0; j <= 1; j++) {
ChatCompletionRequest request = requestBuilder
.model(i == 0 ? "sunspots-raag" : "sunspots-rewriting-raag")
.stream(j == 1)
.build();
ChatCompletionResponse response = client.chatCompletion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
assertEquals(i == 0 ? "I do not know!" : "Peter Meadows", response.content());
}
}
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
@SuppressWarnings("unchecked")
void testFarmCompletionWithMetadataFilterOpenAI() {
Map.of(
"non_existing_document.pdf", false,
"BAAJournalCarringtonEventPaper_compressed.pdf", true)
.forEach((documentName, expected) -> {
Map<String, String> customHeaders = Map.of(
"X-RAG-FILTER-METADATA",
String.format("{{#metadataKey('document_name').isIn('%s')}}", documentName));
OpenAiClient client = createOpenAiClientBuilder()
.customHeaders(customHeaders)
.build();
CompletionRequest request = CompletionRequest.builder()
.model("sunspots-raag")
.prompt("""
Who is the author of document 'The Size of the Carrington Event Sunspot Group'?
Just reply with the firstname and lastname.
""")
.stream(false)
.temperature(0.0)
.build();
CompletionResponse response = client.completion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
if (expected) {
assertEquals("Peter Meadows", response.choices().get(0).text());
} else {
assertNotEquals("Peter Meadows", response.choices().get(0).text());
}
});
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.dragon.test;
package ai.dragon.junit;

public abstract class AbstractTest {
protected boolean canRunOpenAiRelatedTests() {
Expand Down
Loading

0 comments on commit 92176cb

Please sign in to comment.