diff --git a/README.md b/README.md index 6b28e0ba..37ec91a9 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +![Build dRAGon Project](https://github.com/dragon-okinawa/dragon/actions/workflows/build.yml/badge.svg?branch=main) [![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon) [![Vulnerabilities](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=vulnerabilities)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon) [![Bugs](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=bugs)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon) diff --git a/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java b/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java index bb990bdd..fd5520c9 100644 --- a/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java +++ b/backend/src/test/java/ai/dragon/controller/api/raag/OpenAiCompatibleV1ApiControllerTest.java @@ -6,8 +6,13 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.io.File; +import java.io.InterruptedIOException; +import java.net.SocketTimeoutException; import java.time.Duration; import java.util.List; +import java.util.Map; +import java.util.UUID; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -20,12 +25,20 @@ import org.springframework.test.context.ActiveProfiles; import ai.dragon.entity.FarmEntity; +import ai.dragon.entity.SiloEntity; +import ai.dragon.enumeration.EmbeddingModelType; +import ai.dragon.enumeration.IngestorLoaderType; import ai.dragon.enumeration.LanguageModelType; +import ai.dragon.enumeration.VectorStoreType; +import ai.dragon.junit.AbstractTest; +import ai.dragon.junit.extension.retry.RetryOnExceptions; import ai.dragon.repository.FarmRepository; import ai.dragon.repository.SiloRepository; -import ai.dragon.test.AbstractTest; +import ai.dragon.service.IngestorService; import dev.ai4j.openai4j.OpenAiClient; import dev.ai4j.openai4j.OpenAiHttpException; +import dev.ai4j.openai4j.chat.ChatCompletionRequest; +import dev.ai4j.openai4j.chat.ChatCompletionResponse; import dev.ai4j.openai4j.completion.CompletionRequest; import dev.ai4j.openai4j.completion.CompletionResponse; import dev.langchain4j.model.mistralai.internal.api.MistralAiModelResponse; @@ -38,17 +51,67 @@ public class OpenAiCompatibleV1ApiControllerTest extends AbstractTest { private int serverPort; @BeforeAll - static void beforeAll(@Autowired FarmRepository farmRepository, @Autowired SiloRepository siloRepository) { + static void beforeAll(@Autowired FarmRepository farmRepository, + @Autowired SiloRepository siloRepository, + @Autowired IngestorService ingestorService) throws Exception { cleanUp(farmRepository, siloRepository); + // OpenAI settings for RaaG String apiKeySetting = String.format("apiKey=%s", System.getenv("OPENAI_API_KEY")); String modelNameSetting = "modelName=gpt-4o"; + // Farm with no silo FarmEntity farmWithoutSilo = new FarmEntity(); farmWithoutSilo.setRaagIdentifier("no-silo-raag"); farmWithoutSilo.setLanguageModel(LanguageModelType.OpenAiModel); farmWithoutSilo.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting)); farmRepository.save(farmWithoutSilo); + + // Resources for the silo + String ragResourcesPath = "src/test/resources/rag_documents/sunspots"; + File ragResources = new File(ragResourcesPath); + String ragResourcesAbsolutePath = ragResources.getAbsolutePath(); + assertNotNull(ragResourcesAbsolutePath); + + // Silo about "Sunspots" + SiloEntity sunspotsSilo = new SiloEntity(); + sunspotsSilo.setUuid(UUID.randomUUID()); + sunspotsSilo.setName("Sunspots Silo"); + sunspotsSilo.setEmbeddingModel(EmbeddingModelType.BgeSmallEnV15QuantizedEmbeddingModel); + sunspotsSilo.setEmbeddingSettings(List.of( + "chunkSize=1000", + "chunkOverlap=100")); + sunspotsSilo.setVectorStore(VectorStoreType.InMemoryEmbeddingStore); + sunspotsSilo.setIngestorLoader(IngestorLoaderType.FileSystem); + sunspotsSilo.setIngestorSettings(List.of( + String.format("paths[]=%s", ragResourcesAbsolutePath), + "recursive=false", + "pathMatcher=glob:**.{pdf,doc,docx,ppt,pptx}")); + siloRepository.save(sunspotsSilo); + + // Launching ingestion of documents inside the Silo + ingestorService.runSiloIngestion(sunspotsSilo, ingestProgress -> { + System.out.println("Ingest progress: " + ingestProgress); + }, ingestLogMessage -> { + System.out.println(ingestLogMessage.getMessage()); + }); + + // Farm with the Sunspots Silo + FarmEntity farmWithSunspotsSilo = new FarmEntity(); + farmWithSunspotsSilo.setRaagIdentifier("sunspots-raag"); + farmWithSunspotsSilo.setLanguageModel(LanguageModelType.OpenAiModel); + farmWithSunspotsSilo.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting)); + farmWithSunspotsSilo.setSilos(List.of(sunspotsSilo.getUuid())); + farmRepository.save(farmWithSunspotsSilo); + + // Farm with the Sunspots Silo but with Query Rewriting + FarmEntity farmWithSunspotsSiloAndQueryRewriting = new FarmEntity(); + farmWithSunspotsSiloAndQueryRewriting.setRaagIdentifier("sunspots-rewriting-raag"); + farmWithSunspotsSiloAndQueryRewriting.setLanguageModel(LanguageModelType.OpenAiModel); + farmWithSunspotsSiloAndQueryRewriting.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting)); + farmWithSunspotsSiloAndQueryRewriting.setSilos(List.of(sunspotsSilo.getUuid())); + farmWithSunspotsSiloAndQueryRewriting.setRetrievalAugmentorSettings(List.of("rewriteQuery=true")); + farmRepository.save(farmWithSunspotsSiloAndQueryRewriting); } @AfterAll @@ -61,15 +124,31 @@ static void cleanUp(FarmRepository farmRepository, SiloRepository siloRepository siloRepository.deleteAll(); } - @Test - void listModels() throws Exception { - MistralAiClient client = MistralAiClient.builder() + @SuppressWarnings("rawtypes") + private OpenAiClient.Builder createOpenAiClientBuilder() { + return OpenAiClient.builder() + .openAiApiKey("TODO_PUT_KEY_HERE") + .baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort)) + .callTimeout(Duration.ofSeconds(10)) + .readTimeout(Duration.ofSeconds(10)) + .writeTimeout(Duration.ofSeconds(10)) + .connectTimeout(Duration.ofSeconds(10)); + } + + @SuppressWarnings("rawtypes") + private MistralAiClient.Builder createMistralAiClientBuilder() { + return MistralAiClient.builder() .apiKey("TODO_PUT_KEY_HERE") .baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort)) .timeout(Duration.ofSeconds(10)) .logRequests(false) - .logResponses(false) - .build(); + .logResponses(false); + } + + @Test + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void listModels() throws Exception { + MistralAiClient client = createMistralAiClientBuilder().build(); MistralAiModelResponse modelsResponse = client.listModels(); assertNotNull(modelsResponse); assertNotEquals(0, modelsResponse.getData().size()); @@ -77,11 +156,9 @@ void listModels() throws Exception { @Test @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) void testModelDoesntExistOpenAI() { - OpenAiClient client = OpenAiClient.builder() - .openAiApiKey("TODO_PUT_KEY_HERE") - .baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort)) - .build(); + OpenAiClient client = createOpenAiClientBuilder().build(); CompletionRequest request = CompletionRequest.builder() .model("should-not-exist") .prompt("Just say : 'dRAGon'") @@ -93,14 +170,13 @@ void testModelDoesntExistOpenAI() { @Test @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) void testFarmNoSiloOpenAI() { - OpenAiClient client = OpenAiClient.builder() - .openAiApiKey("TODO_PUT_KEY_HERE") - .baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort)) - .build(); + OpenAiClient client = createOpenAiClientBuilder().build(); CompletionRequest request = CompletionRequest.builder() .model("no-silo-raag") .prompt("Just say 'HELLO' in lowercased letters.") + .temperature(0.0) .build(); CompletionResponse response = client.completion(request).execute(); assertNotNull(response); @@ -111,19 +187,110 @@ void testFarmNoSiloOpenAI() { @Test @EnabledIf("canRunOpenAiRelatedTests") - void testFarmWithSilosOpenAI() { - OpenAiClient client = OpenAiClient.builder() - .openAiApiKey("TODO_PUT_KEY_HERE") - .baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort)) + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void testFarmCompletionOpenAI() { + OpenAiClient client = createOpenAiClientBuilder().build(); + CompletionRequest request = CompletionRequest.builder() + .model("sunspots-raag") + .prompt("Who is the author of document 'The Size of the Carrington Event Sunspot Group'? Just reply with the firstname and lastname.") + .stream(false) + .temperature(0.0) .build(); + CompletionResponse response = client.completion(request).execute(); + assertNotNull(response); + assertNotNull(response.choices()); + assertNotEquals(0, response.choices().size()); + assertEquals("Peter Meadows", response.choices().get(0).text()); + } + + @Test + @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void testFarmCompletionStreamOpenAI() { + OpenAiClient client = createOpenAiClientBuilder().build(); CompletionRequest request = CompletionRequest.builder() - .model("no-silo-raag") - .prompt("Just say 'HELLO' in lowercased letters.") + .model("sunspots-raag") + .prompt(""" + Who is the author of document 'Sunspots, unemployment, and recessions, or Can the solar activity cycle shape the business cycle?'? + Just reply with the firstname and lastname. + """) + .stream(true) + .temperature(0.0) .build(); CompletionResponse response = client.completion(request).execute(); assertNotNull(response); assertNotNull(response.choices()); assertNotEquals(0, response.choices().size()); - assertEquals("hello", response.choices().get(0).text()); + assertEquals("Mikhail Gorbanev", response.choices().get(0).text()); + } + + @Test + @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + void testFarmChatRewriteQueryOpenAI() { + OpenAiClient client = createOpenAiClientBuilder().build(); + ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder() + .addSystemMessage( + "You are a researcher in solar physics and you provide help to other researchers.") + .addUserMessage( + "Hello, I am looking for the author of the document 'The Size of the Carrington Event Sunspot Group'.") + .addUserMessage( + """ + Can you help me? + * Just use the context of this message and nothing else to reply. + * If the information is not provided, just say 'I do not know!' + * Just reply with the firstname and lastname. + """) + .temperature(0.0); + + for (int i = 0; i <= 1; i++) { + for (int j = 0; j <= 1; j++) { + ChatCompletionRequest request = requestBuilder + .model(i == 0 ? "sunspots-raag" : "sunspots-rewriting-raag") + .stream(j == 1) + .build(); + ChatCompletionResponse response = client.chatCompletion(request).execute(); + assertNotNull(response); + assertNotNull(response.choices()); + assertNotEquals(0, response.choices().size()); + assertEquals(i == 0 ? "I do not know!" : "Peter Meadows", response.content()); + } + } + } + + @Test + @EnabledIf("canRunOpenAiRelatedTests") + @RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class }) + @SuppressWarnings("unchecked") + void testFarmCompletionWithMetadataFilterOpenAI() { + Map.of( + "non_existing_document.pdf", false, + "BAAJournalCarringtonEventPaper_compressed.pdf", true) + .forEach((documentName, expected) -> { + Map customHeaders = Map.of( + "X-RAG-FILTER-METADATA", + String.format("{{#metadataKey('document_name').isIn('%s')}}", documentName)); + OpenAiClient client = createOpenAiClientBuilder() + .customHeaders(customHeaders) + .build(); + CompletionRequest request = CompletionRequest.builder() + .model("sunspots-raag") + .prompt(""" + Who is the author of document 'The Size of the Carrington Event Sunspot Group'? + Just reply with the firstname and lastname. + """) + .stream(false) + .temperature(0.0) + .build(); + CompletionResponse response = client.completion(request).execute(); + assertNotNull(response); + assertNotNull(response.choices()); + assertNotEquals(0, response.choices().size()); + if (expected) { + assertEquals("Peter Meadows", response.choices().get(0).text()); + } else { + assertNotEquals("Peter Meadows", response.choices().get(0).text()); + } + }); } } diff --git a/backend/src/test/java/ai/dragon/test/AbstractTest.java b/backend/src/test/java/ai/dragon/junit/AbstractTest.java similarity index 88% rename from backend/src/test/java/ai/dragon/test/AbstractTest.java rename to backend/src/test/java/ai/dragon/junit/AbstractTest.java index 38800552..5be4ebd2 100644 --- a/backend/src/test/java/ai/dragon/test/AbstractTest.java +++ b/backend/src/test/java/ai/dragon/junit/AbstractTest.java @@ -1,4 +1,4 @@ -package ai.dragon.test; +package ai.dragon.junit; public abstract class AbstractTest { protected boolean canRunOpenAiRelatedTests() { diff --git a/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptions.java b/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptions.java new file mode 100644 index 00000000..65a7ea76 --- /dev/null +++ b/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptions.java @@ -0,0 +1,16 @@ +package ai.dragon.junit.extension.retry; + +import org.junit.jupiter.api.extension.ExtendWith; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@ExtendWith(RetryOnExceptionsExtension.class) +public @interface RetryOnExceptions { + int value() default 3; // Default retry count + Class[] onExceptions() default { Throwable.class }; // Exceptions to retry on +} diff --git a/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptionsExtension.java b/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptionsExtension.java new file mode 100644 index 00000000..94a886e9 --- /dev/null +++ b/backend/src/test/java/ai/dragon/junit/extension/retry/RetryOnExceptionsExtension.java @@ -0,0 +1,51 @@ +package ai.dragon.junit.extension.retry; + +import java.lang.reflect.Method; + +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ExtensionContext.Store; +import org.junit.jupiter.api.extension.TestExecutionExceptionHandler; + +public class RetryOnExceptionsExtension implements TestExecutionExceptionHandler, BeforeEachCallback { + + @Override + public void handleTestExecutionException(ExtensionContext context, Throwable throwable) throws Throwable { + Method testMethod = context.getRequiredTestMethod(); + RetryOnExceptions retry = testMethod.getAnnotation(RetryOnExceptions.class); + + if (retry != null) { + int maxRetries = retry.value(); + Class[] retryOnExceptions = retry.onExceptions(); + Store store = context.getStore(ExtensionContext.Namespace.create(testMethod)); + + int currentRetries = store.getOrDefault("retries", Integer.class, 0); + + boolean shouldRetry = false; + for (Class retryOnException : retryOnExceptions) { + if (retryOnException.isInstance(throwable)) { + shouldRetry = true; + break; + } + } + + if (shouldRetry && currentRetries < maxRetries) { + store.put("retries", currentRetries + 1); + context.getRequiredTestMethod().invoke(context.getRequiredTestInstance()); + } else { + throw throwable; + } + } else { + throw throwable; + } + } + + @Override + public void beforeEach(ExtensionContext context) throws Exception { + Method testMethod = context.getRequiredTestMethod(); + if (testMethod.isAnnotationPresent(RetryOnExceptions.class)) { + Store store = context.getStore(ExtensionContext.Namespace.create(testMethod)); + store.put("retries", 0); // Reset retry count before each test + } + } +}