diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9c8b8d6a..47863c15 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,7 +4,7 @@ on: - workflow_dispatch concurrency: group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: false + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: build: runs-on: ubuntu-latest diff --git a/README.md b/README.md index 6b28e0ba..37ec91a9 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +![Build dRAGon Project](https://github.com/dragon-okinawa/dragon/actions/workflows/build.yml/badge.svg?branch=main) [![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon) [![Vulnerabilities](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=vulnerabilities)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon) [![Bugs](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=bugs)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon) diff --git a/backend/build.gradle b/backend/build.gradle index c45f141b..ee5005a1 100644 --- a/backend/build.gradle +++ b/backend/build.gradle @@ -1,3 +1,6 @@ +import org.gradle.api.tasks.testing.logging.TestExceptionFormat +import org.gradle.api.tasks.testing.logging.TestLogEvent + plugins { id 'java' id 'java-library' @@ -6,7 +9,7 @@ plugins { id 'application' id 'org.springframework.boot' version '3.3.2' id 'io.spring.dependency-management' version '1.1.6' - id 'org.springdoc.openapi-gradle-plugin' version '1.8.0' + id 'org.springdoc.openapi-gradle-plugin' version '1.9.0' } group = 'ai' @@ -141,3 +144,25 @@ tasks.named('compileJava') { tasks.named('processResources') { dependsOn('copyWebApp') } + +// Based on : https://stackoverflow.com/a/36130467/8102448 +tasks.withType(Test) { + testLogging { + events TestLogEvent.FAILED, + TestLogEvent.PASSED, + TestLogEvent.SKIPPED + exceptionFormat TestExceptionFormat.FULL + showExceptions true + showCauses true + showStackTraces true + + afterSuite { desc, result -> + if (!desc.parent) { + def output = "Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} passed, ${result.failedTestCount} failed, ${result.skippedTestCount} skipped)" + def startItem = '| ', endItem = ' |' + def repeatLength = startItem.length() + output.length() + endItem.length() + println('\n' + ('-' * repeatLength) + '\n' + startItem + output + endItem + '\n' + ('-' * repeatLength)) + } + } + } +} \ No newline at end of file diff --git a/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java b/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java index f1d7f98f..fd5520c9 100644 --- a/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java +++ b/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java @@ -3,13 +3,19 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import java.io.File; +import java.io.InterruptedIOException; +import java.net.SocketTimeoutException; import java.time.Duration; import java.util.List; +import java.util.Map; +import java.util.UUID; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIf; import org.springframework.beans.factory.annotation.Autowired; @@ -19,10 +25,20 @@ import org.springframework.test.context.ActiveProfiles; import ai.dragon.entity.FarmEntity; +import ai.dragon.entity.SiloEntity; +import ai.dragon.enumeration.EmbeddingModelType; +import ai.dragon.enumeration.IngestorLoaderType; import ai.dragon.enumeration.LanguageModelType; +import ai.dragon.enumeration.VectorStoreType; +import ai.dragon.junit.AbstractTest; +import ai.dragon.junit.extension.retry.RetryOnExceptions; import ai.dragon.repository.FarmRepository; -import ai.dragon.test.AbstractTest; +import ai.dragon.repository.SiloRepository; +import ai.dragon.service.IngestorService; import dev.ai4j.openai4j.OpenAiClient; +import dev.ai4j.openai4j.OpenAiHttpException; +import dev.ai4j.openai4j.chat.ChatCompletionRequest; +import dev.ai4j.openai4j.chat.ChatCompletionResponse; import dev.ai4j.openai4j.completion.CompletionRequest; import dev.ai4j.openai4j.completion.CompletionResponse; import dev.langchain4j.model.mistralai.internal.api.MistralAiModelResponse; @@ -34,62 +50,133 @@ public class OpenAiCompatibleV1ApiControllerTest extends AbstractTest { @LocalServerPort private int serverPort; - @Autowired - private FarmRepository farmRepository; - @BeforeAll - static void beforeAll(@Autowired FarmRepository farmRepository) { - farmRepository.deleteAll(); + static void beforeAll(@Autowired FarmRepository farmRepository, + @Autowired SiloRepository siloRepository, + @Autowired IngestorService ingestorService) throws Exception { + cleanUp(farmRepository, siloRepository); + + // OpenAI settings for RaaG + String apiKeySetting = String.format("apiKey=%s", System.getenv("OPENAI_API_KEY")); + String modelNameSetting = "modelName=gpt-4o"; + + // Farm with no silo + FarmEntity farmWithoutSilo = new FarmEntity(); + farmWithoutSilo.setRaagIdentifier("no-silo-raag"); + farmWithoutSilo.setLanguageModel(LanguageModelType.OpenAiModel); + farmWithoutSilo.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting)); + farmRepository.save(farmWithoutSilo); + + // Resources for the silo + String ragResourcesPath = "src/test/resources/rag_documents/sunspots"; + File ragResources = new File(ragResourcesPath); + String ragResourcesAbsolutePath = ragResources.getAbsolutePath(); + assertNotNull(ragResourcesAbsolutePath); + + // Silo about "Sunspots" + SiloEntity sunspotsSilo = new SiloEntity(); + sunspotsSilo.setUuid(UUID.randomUUID()); + sunspotsSilo.setName("Sunspots Silo"); + sunspotsSilo.setEmbeddingModel(EmbeddingModelType.BgeSmallEnV15QuantizedEmbeddingModel); + sunspotsSilo.setEmbeddingSettings(List.of( + "chunkSize=1000", + "chunkOverlap=100")); + sunspotsSilo.setVectorStore(VectorStoreType.InMemoryEmbeddingStore); + sunspotsSilo.setIngestorLoader(IngestorLoaderType.FileSystem); + sunspotsSilo.setIngestorSettings(List.of( + String.format("paths[]=%s", ragResourcesAbsolutePath), + "recursive=false", + "pathMatcher=glob:**.{pdf,doc,docx,ppt,pptx}")); + siloRepository.save(sunspotsSilo); + + // Launching ingestion of documents inside the Silo + ingestorService.runSiloIngestion(sunspotsSilo, ingestProgress -> { + System.out.println("Ingest progress: " + ingestProgress); + }, ingestLogMessage -> { + System.out.println(ingestLogMessage.getMessage()); + }); + + // Farm with the Sunspots Silo + FarmEntity farmWithSunspotsSilo = new FarmEntity(); + farmWithSunspotsSilo.setRaagIdentifier("sunspots-raag"); + farmWithSunspotsSilo.setLanguageModel(LanguageModelType.OpenAiModel); + farmWithSunspotsSilo.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting)); + farmWithSunspotsSilo.setSilos(List.of(sunspotsSilo.getUuid())); + farmRepository.save(farmWithSunspotsSilo); + + // Farm with the Sunspots Silo but with Query Rewriting + FarmEntity farmWithSunspotsSiloAndQueryRewriting = new FarmEntity(); + farmWithSunspotsSiloAndQueryRewriting.setRaagIdentifier("sunspots-rewriting-raag"); + farmWithSunspotsSiloAndQueryRewriting.setLanguageModel(LanguageModelType.OpenAiModel); + farmWithSunspotsSiloAndQueryRewriting.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting)); + farmWithSunspotsSiloAndQueryRewriting.setSilos(List.of(sunspotsSilo.getUuid())); + farmWithSunspotsSiloAndQueryRewriting.setRetrievalAugmentorSettings(List.of("rewriteQuery=true")); + farmRepository.save(farmWithSunspotsSiloAndQueryRewriting); } @AfterAll - static void afterAll(@Autowired FarmRepository farmRepository) { - farmRepository.deleteAll(); + static void afterAll(@Autowired FarmRepository farmRepository, @Autowired SiloRepository siloRepository) { + cleanUp(farmRepository, siloRepository); } - @BeforeEach - void beforeEach(@Autowired FarmRepository farmRepository) { + static void cleanUp(FarmRepository farmRepository, SiloRepository siloRepository) { farmRepository.deleteAll(); + siloRepository.deleteAll(); } - @Test - void listModels() throws Exception { - FarmEntity farm = new FarmEntity(); - farm.setRaagIdentifier("awesome-raag"); - farmRepository.save(farm); + @SuppressWarnings("rawtypes") + private OpenAiClient.Builder createOpenAiClientBuilder() { + return OpenAiClient.builder() + .openAiApiKey("TODO_PUT_KEY_HERE") + .baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort)) + .callTimeout(Duration.ofSeconds(10)) + .readTimeout(Duration.ofSeconds(10)) + .writeTimeout(Duration.ofSeconds(10)) + .connectTimeout(Duration.ofSeconds(10)); + } - MistralAiClient client = MistralAiClient.builder() + @SuppressWarnings("rawtypes") + private MistralAiClient.Builder createMistralAiClientBuilder() { + return MistralAiClient.builder() .apiKey("TODO_PUT_KEY_HERE") .baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort)) .timeout(Duration.ofSeconds(10)) .logRequests(false) - .logResponses(false) - .build(); + .logResponses(false); + } + + @Test + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void listModels() throws Exception { + MistralAiClient client = createMistralAiClientBuilder().build(); MistralAiModelResponse modelsResponse = client.listModels(); assertNotNull(modelsResponse); - assertEquals(1, modelsResponse.getData().size()); - assertEquals(farm.getRaagIdentifier(), modelsResponse.getData().get(0).getId()); + assertNotEquals(0, modelsResponse.getData().size()); } @Test @EnabledIf("canRunOpenAiRelatedTests") - void testFarmNoSiloOpenAI() { - String apiKeySetting = String.format("apiKey=%s", System.getenv("OPENAI_API_KEY")); - String modelNameSetting = "modelName=gpt-4o"; - - FarmEntity farm = new FarmEntity(); - farm.setRaagIdentifier("dragon-raag"); - farm.setLanguageModel(LanguageModelType.OpenAiModel); - farm.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting)); - farmRepository.save(farm); - - OpenAiClient client = OpenAiClient.builder() - .openAiApiKey("TODO_PUT_KEY_HERE") - .baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort)) + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void testModelDoesntExistOpenAI() { + OpenAiClient client = createOpenAiClientBuilder().build(); + CompletionRequest request = CompletionRequest.builder() + .model("should-not-exist") + .prompt("Just say : 'dRAGon'") .build(); + OpenAiHttpException exception = assertThrows(OpenAiHttpException.class, + () -> client.completion(request).execute()); + assertTrue(exception.code() == 404); + } + + @Test + @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void testFarmNoSiloOpenAI() { + OpenAiClient client = createOpenAiClientBuilder().build(); CompletionRequest request = CompletionRequest.builder() - .model("dragon-raag") + .model("no-silo-raag") .prompt("Just say 'HELLO' in lowercased letters.") + .temperature(0.0) .build(); CompletionResponse response = client.completion(request).execute(); assertNotNull(response); @@ -97,4 +184,113 @@ void testFarmNoSiloOpenAI() { assertNotEquals(0, response.choices().size()); assertEquals("hello", response.choices().get(0).text()); } + + @Test + @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void testFarmCompletionOpenAI() { + OpenAiClient client = createOpenAiClientBuilder().build(); + CompletionRequest request = CompletionRequest.builder() + .model("sunspots-raag") + .prompt("Who is the author of document 'The Size of the Carrington Event Sunspot Group'? Just reply with the firstname and lastname.") + .stream(false) + .temperature(0.0) + .build(); + CompletionResponse response = client.completion(request).execute(); + assertNotNull(response); + assertNotNull(response.choices()); + assertNotEquals(0, response.choices().size()); + assertEquals("Peter Meadows", response.choices().get(0).text()); + } + + @Test + @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void testFarmCompletionStreamOpenAI() { + OpenAiClient client = createOpenAiClientBuilder().build(); + CompletionRequest request = CompletionRequest.builder() + .model("sunspots-raag") + .prompt(""" + Who is the author of document 'Sunspots, unemployment, and recessions, or Can the solar activity cycle shape the business cycle?'? + Just reply with the firstname and lastname. + """) + .stream(true) + .temperature(0.0) + .build(); + CompletionResponse response = client.completion(request).execute(); + assertNotNull(response); + assertNotNull(response.choices()); + assertNotEquals(0, response.choices().size()); + assertEquals("Mikhail Gorbanev", response.choices().get(0).text()); + } + + @Test + @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void testFarmChatRewriteQueryOpenAI() { + OpenAiClient client = createOpenAiClientBuilder().build(); + ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder() + .addSystemMessage( + "You are a researcher in solar physics and you provide help to other researchers.") + .addUserMessage( + "Hello, I am looking for the author of the document 'The Size of the Carrington Event Sunspot Group'.") + .addUserMessage( + """ + Can you help me? + * Just use the context of this message and nothing else to reply. + * If the information is not provided, just say 'I do not know!' + * Just reply with the firstname and lastname. + """) + .temperature(0.0); + + for (int i = 0; i <= 1; i++) { + for (int j = 0; j <= 1; j++) { + ChatCompletionRequest request = requestBuilder + .model(i == 0 ? "sunspots-raag" : "sunspots-rewriting-raag") + .stream(j == 1) + .build(); + ChatCompletionResponse response = client.chatCompletion(request).execute(); + assertNotNull(response); + assertNotNull(response.choices()); + assertNotEquals(0, response.choices().size()); + assertEquals(i == 0 ? "I do not know!" : "Peter Meadows", response.content()); + } + } + } + + @Test + @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + @SuppressWarnings("unchecked") + void testFarmCompletionWithMetadataFilterOpenAI() { + Map.of( + "non_existing_document.pdf", false, + "BAAJournalCarringtonEventPaper_compressed.pdf", true) + .forEach((documentName, expected) -> { + Map customHeaders = Map.of( + "X-RAG-FILTER-METADATA", + String.format("{{#metadataKey('document_name').isIn('%s')}}", documentName)); + OpenAiClient client = createOpenAiClientBuilder() + .customHeaders(customHeaders) + .build(); + CompletionRequest request = CompletionRequest.builder() + .model("sunspots-raag") + .prompt(""" + Who is the author of document 'The Size of the Carrington Event Sunspot Group'? + Just reply with the firstname and lastname. + """) + .stream(false) + .temperature(0.0) + .build(); + CompletionResponse response = client.completion(request).execute(); + assertNotNull(response); + assertNotNull(response.choices()); + assertNotEquals(0, response.choices().size()); + if (expected) { + assertEquals("Peter Meadows", response.choices().get(0).text()); + } else { + assertNotEquals("Peter Meadows", response.choices().get(0).text()); + } + }); + } } diff --git a/backend/src/test/java/ai/dragon/test/AbstractTest.java b/backend/src/test/java/ai/dragon/junit/AbstractTest.java similarity index 88% rename from backend/src/test/java/ai/dragon/test/AbstractTest.java rename to backend/src/test/java/ai/dragon/junit/AbstractTest.java index 38800552..5be4ebd2 100644 --- a/backend/src/test/java/ai/dragon/test/AbstractTest.java +++ b/backend/src/test/java/ai/dragon/junit/AbstractTest.java @@ -1,4 +1,4 @@ -package ai.dragon.test; +package ai.dragon.junit; public abstract class AbstractTest { protected boolean canRunOpenAiRelatedTests() { diff --git a/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptions.java b/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptions.java new file mode 100644 index 00000000..65a7ea76 --- /dev/null +++ b/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptions.java @@ -0,0 +1,16 @@ +package ai.dragon.junit.extension.retry; + +import org.junit.jupiter.api.extension.ExtendWith; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@ExtendWith(RetryOnExceptionsExtension.class) +public @interface RetryOnExceptions { + int value() default 3; // Default retry count + Class[] onExceptions() default { Throwable.class }; // Exceptions to retry on +} diff --git a/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptionsExtension.java b/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptionsExtension.java new file mode 100644 index 00000000..94a886e9 --- /dev/null +++ b/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptionsExtension.java @@ -0,0 +1,51 @@ +package ai.dragon.junit.extension.retry; + +import java.lang.reflect.Method; + +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ExtensionContext.Store; +import org.junit.jupiter.api.extension.TestExecutionExceptionHandler; + +public class RetryOnExceptionsExtension implements TestExecutionExceptionHandler, BeforeEachCallback { + + @Override + public void handleTestExecutionException(ExtensionContext context, Throwable throwable) throws Throwable { + Method testMethod = context.getRequiredTestMethod(); + RetryOnExceptions retry = testMethod.getAnnotation(RetryOnExceptions.class); + + if (retry != null) { + int maxRetries = retry.value(); + Class[] retryOnExceptions = retry.onExceptions(); + Store store = context.getStore(ExtensionContext.Namespace.create(testMethod)); + + int currentRetries = store.getOrDefault("retries", Integer.class, 0); + + boolean shouldRetry = false; + for (Class retryOnException : retryOnExceptions) { + if (retryOnException.isInstance(throwable)) { + shouldRetry = true; + break; + } + } + + if (shouldRetry && currentRetries < maxRetries) { + store.put("retries", currentRetries + 1); + context.getRequiredTestMethod().invoke(context.getRequiredTestInstance()); + } else { + throw throwable; + } + } else { + throw throwable; + } + } + + @Override + public void beforeEach(ExtensionContext context) throws Exception { + Method testMethod = context.getRequiredTestMethod(); + if (testMethod.isAnnotationPresent(RetryOnExceptions.class)) { + Store store = context.getStore(ExtensionContext.Namespace.create(testMethod)); + store.put("retries", 0); // Reset retry count before each test + } + } +}