Skip to content

Commit

Permalink
Add build status badge to README and update OpenAiCompatibleV1ApiCont…
Browse files Browse the repository at this point in the history
…rollerTest (#184)

* Fail testFarmWithSunspotsSiloOpenAI

* chore: Update OpenAiCompatibleV1ApiControllerTest with new models and tests

* chore: Update OpenAiCompatibleV1ApiControllerTest with new models and tests

* chore: Update README.md with build status badge

* chore: Remove TODO comment

* chore: Add RetryOnExceptionsExtension for test execution exception handling

* chore: Add RetryOnExceptionsExtension for test execution exception handling

* Revert back to "gpt-4o" (instead of "gpt-4o-mini") for OpenAiCompatibleV1ApiControllerTest due too many IOException

---------

Co-authored-by: Arnaud Mengus <[email protected]>
  • Loading branch information
amengus87 and isontheline authored Jul 26, 2024
1 parent 4f6d576 commit dfa4b07
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 23 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
![Build dRAGon Project](https://github.com/dragon-okinawa/dragon/actions/workflows/build.yml/badge.svg?branch=main)
[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon)
[![Vulnerabilities](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=vulnerabilities)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon)
[![Bugs](https://sonarcloud.io/api/project_badges/measure?project=dRAGon-Okinawa_dRAGon&metric=bugs)](https://sonarcloud.io/summary/new_code?id=dRAGon-Okinawa_dRAGon)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.io.InterruptedIOException;
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
Expand All @@ -20,12 +25,20 @@
import org.springframework.test.context.ActiveProfiles;

import ai.dragon.entity.FarmEntity;
import ai.dragon.entity.SiloEntity;
import ai.dragon.enumeration.EmbeddingModelType;
import ai.dragon.enumeration.IngestorLoaderType;
import ai.dragon.enumeration.LanguageModelType;
import ai.dragon.enumeration.VectorStoreType;
import ai.dragon.junit.AbstractTest;
import ai.dragon.junit.extension.retry.RetryOnExceptions;
import ai.dragon.repository.FarmRepository;
import ai.dragon.repository.SiloRepository;
import ai.dragon.test.AbstractTest;
import ai.dragon.service.IngestorService;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.completion.CompletionRequest;
import dev.ai4j.openai4j.completion.CompletionResponse;
import dev.langchain4j.model.mistralai.internal.api.MistralAiModelResponse;
Expand All @@ -38,17 +51,67 @@ public class OpenAiCompatibleV1ApiControllerTest extends AbstractTest {
private int serverPort;

@BeforeAll
static void beforeAll(@Autowired FarmRepository farmRepository, @Autowired SiloRepository siloRepository) {
static void beforeAll(@Autowired FarmRepository farmRepository,
@Autowired SiloRepository siloRepository,
@Autowired IngestorService ingestorService) throws Exception {
cleanUp(farmRepository, siloRepository);

// OpenAI settings for RaaG
String apiKeySetting = String.format("apiKey=%s", System.getenv("OPENAI_API_KEY"));
String modelNameSetting = "modelName=gpt-4o";

// Farm with no silo
FarmEntity farmWithoutSilo = new FarmEntity();
farmWithoutSilo.setRaagIdentifier("no-silo-raag");
farmWithoutSilo.setLanguageModel(LanguageModelType.OpenAiModel);
farmWithoutSilo.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting));
farmRepository.save(farmWithoutSilo);

// Resources for the silo
String ragResourcesPath = "src/test/resources/rag_documents/sunspots";
File ragResources = new File(ragResourcesPath);
String ragResourcesAbsolutePath = ragResources.getAbsolutePath();
assertNotNull(ragResourcesAbsolutePath);

// Silo about "Sunspots"
SiloEntity sunspotsSilo = new SiloEntity();
sunspotsSilo.setUuid(UUID.randomUUID());
sunspotsSilo.setName("Sunspots Silo");
sunspotsSilo.setEmbeddingModel(EmbeddingModelType.BgeSmallEnV15QuantizedEmbeddingModel);
sunspotsSilo.setEmbeddingSettings(List.of(
"chunkSize=1000",
"chunkOverlap=100"));
sunspotsSilo.setVectorStore(VectorStoreType.InMemoryEmbeddingStore);
sunspotsSilo.setIngestorLoader(IngestorLoaderType.FileSystem);
sunspotsSilo.setIngestorSettings(List.of(
String.format("paths[]=%s", ragResourcesAbsolutePath),
"recursive=false",
"pathMatcher=glob:**.{pdf,doc,docx,ppt,pptx}"));
siloRepository.save(sunspotsSilo);

// Launching ingestion of documents inside the Silo
ingestorService.runSiloIngestion(sunspotsSilo, ingestProgress -> {
System.out.println("Ingest progress: " + ingestProgress);
}, ingestLogMessage -> {
System.out.println(ingestLogMessage.getMessage());
});

// Farm with the Sunspots Silo
FarmEntity farmWithSunspotsSilo = new FarmEntity();
farmWithSunspotsSilo.setRaagIdentifier("sunspots-raag");
farmWithSunspotsSilo.setLanguageModel(LanguageModelType.OpenAiModel);
farmWithSunspotsSilo.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting));
farmWithSunspotsSilo.setSilos(List.of(sunspotsSilo.getUuid()));
farmRepository.save(farmWithSunspotsSilo);

// Farm with the Sunspots Silo but with Query Rewriting
FarmEntity farmWithSunspotsSiloAndQueryRewriting = new FarmEntity();
farmWithSunspotsSiloAndQueryRewriting.setRaagIdentifier("sunspots-rewriting-raag");
farmWithSunspotsSiloAndQueryRewriting.setLanguageModel(LanguageModelType.OpenAiModel);
farmWithSunspotsSiloAndQueryRewriting.setLanguageModelSettings(List.of(apiKeySetting, modelNameSetting));
farmWithSunspotsSiloAndQueryRewriting.setSilos(List.of(sunspotsSilo.getUuid()));
farmWithSunspotsSiloAndQueryRewriting.setRetrievalAugmentorSettings(List.of("rewriteQuery=true"));
farmRepository.save(farmWithSunspotsSiloAndQueryRewriting);
}

@AfterAll
Expand All @@ -61,27 +124,41 @@ static void cleanUp(FarmRepository farmRepository, SiloRepository siloRepository
siloRepository.deleteAll();
}

@Test
void listModels() throws Exception {
MistralAiClient client = MistralAiClient.builder()
@SuppressWarnings("rawtypes")
private OpenAiClient.Builder createOpenAiClientBuilder() {
return OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
.baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort))
.callTimeout(Duration.ofSeconds(10))
.readTimeout(Duration.ofSeconds(10))
.writeTimeout(Duration.ofSeconds(10))
.connectTimeout(Duration.ofSeconds(10));
}

@SuppressWarnings("rawtypes")
private MistralAiClient.Builder createMistralAiClientBuilder() {
return MistralAiClient.builder()
.apiKey("TODO_PUT_KEY_HERE")
.baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort))
.timeout(Duration.ofSeconds(10))
.logRequests(false)
.logResponses(false)
.build();
.logResponses(false);
}

@Test
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void listModels() throws Exception {
MistralAiClient client = createMistralAiClientBuilder().build();
MistralAiModelResponse modelsResponse = client.listModels();
assertNotNull(modelsResponse);
assertNotEquals(0, modelsResponse.getData().size());
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testModelDoesntExistOpenAI() {
OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
.baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort))
.build();
OpenAiClient client = createOpenAiClientBuilder().build();
CompletionRequest request = CompletionRequest.builder()
.model("should-not-exist")
.prompt("Just say : 'dRAGon'")
Expand All @@ -93,14 +170,13 @@ void testModelDoesntExistOpenAI() {

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmNoSiloOpenAI() {
OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
.baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort))
.build();
OpenAiClient client = createOpenAiClientBuilder().build();
CompletionRequest request = CompletionRequest.builder()
.model("no-silo-raag")
.prompt("Just say 'HELLO' in lowercased letters.")
.temperature(0.0)
.build();
CompletionResponse response = client.completion(request).execute();
assertNotNull(response);
Expand All @@ -111,19 +187,110 @@ void testFarmNoSiloOpenAI() {

@Test
@EnabledIf("canRunOpenAiRelatedTests")
void testFarmWithSilosOpenAI() {
OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
.baseUrl(String.format("http://localhost:%d/api/raag/v1/", serverPort))
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmCompletionOpenAI() {
OpenAiClient client = createOpenAiClientBuilder().build();
CompletionRequest request = CompletionRequest.builder()
.model("sunspots-raag")
.prompt("Who is the author of document 'The Size of the Carrington Event Sunspot Group'? Just reply with the firstname and lastname.")
.stream(false)
.temperature(0.0)
.build();
CompletionResponse response = client.completion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
assertEquals("Peter Meadows", response.choices().get(0).text());
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmCompletionStreamOpenAI() {
OpenAiClient client = createOpenAiClientBuilder().build();
CompletionRequest request = CompletionRequest.builder()
.model("no-silo-raag")
.prompt("Just say 'HELLO' in lowercased letters.")
.model("sunspots-raag")
.prompt("""
Who is the author of document 'Sunspots, unemployment, and recessions, or Can the solar activity cycle shape the business cycle?'?
Just reply with the firstname and lastname.
""")
.stream(true)
.temperature(0.0)
.build();
CompletionResponse response = client.completion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
assertEquals("hello", response.choices().get(0).text());
assertEquals("Mikhail Gorbanev", response.choices().get(0).text());
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmChatRewriteQueryOpenAI() {
OpenAiClient client = createOpenAiClientBuilder().build();
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.addSystemMessage(
"You are a researcher in solar physics and you provide help to other researchers.")
.addUserMessage(
"Hello, I am looking for the author of the document 'The Size of the Carrington Event Sunspot Group'.")
.addUserMessage(
"""
Can you help me?
* Just use the context of this message and nothing else to reply.
* If the information is not provided, just say 'I do not know!'
* Just reply with the firstname and lastname.
""")
.temperature(0.0);

for (int i = 0; i <= 1; i++) {
for (int j = 0; j <= 1; j++) {
ChatCompletionRequest request = requestBuilder
.model(i == 0 ? "sunspots-raag" : "sunspots-rewriting-raag")
.stream(j == 1)
.build();
ChatCompletionResponse response = client.chatCompletion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
assertEquals(i == 0 ? "I do not know!" : "Peter Meadows", response.content());
}
}
}

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
@SuppressWarnings("unchecked")
void testFarmCompletionWithMetadataFilterOpenAI() {
Map.of(
"non_existing_document.pdf", false,
"BAAJournalCarringtonEventPaper_compressed.pdf", true)
.forEach((documentName, expected) -> {
Map<String, String> customHeaders = Map.of(
"X-RAG-FILTER-METADATA",
String.format("{{#metadataKey('document_name').isIn('%s')}}", documentName));
OpenAiClient client = createOpenAiClientBuilder()
.customHeaders(customHeaders)
.build();
CompletionRequest request = CompletionRequest.builder()
.model("sunspots-raag")
.prompt("""
Who is the author of document 'The Size of the Carrington Event Sunspot Group'?
Just reply with the firstname and lastname.
""")
.stream(false)
.temperature(0.0)
.build();
CompletionResponse response = client.completion(request).execute();
assertNotNull(response);
assertNotNull(response.choices());
assertNotEquals(0, response.choices().size());
if (expected) {
assertEquals("Peter Meadows", response.choices().get(0).text());
} else {
assertNotEquals("Peter Meadows", response.choices().get(0).text());
}
});
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.dragon.test;
package ai.dragon.junit;

public abstract class AbstractTest {
protected boolean canRunOpenAiRelatedTests() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package ai.dragon.junit.extension.retry;

import org.junit.jupiter.api.extension.ExtendWith;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@ExtendWith(RetryOnExceptionsExtension.class)
public @interface RetryOnExceptions {
int value() default 3; // Default retry count
Class<? extends Throwable>[] onExceptions() default { Throwable.class }; // Exceptions to retry on
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package ai.dragon.junit.extension.retry;

import java.lang.reflect.Method;

import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;

public class RetryOnExceptionsExtension implements TestExecutionExceptionHandler, BeforeEachCallback {

@Override
public void handleTestExecutionException(ExtensionContext context, Throwable throwable) throws Throwable {
Method testMethod = context.getRequiredTestMethod();
RetryOnExceptions retry = testMethod.getAnnotation(RetryOnExceptions.class);

if (retry != null) {
int maxRetries = retry.value();
Class<? extends Throwable>[] retryOnExceptions = retry.onExceptions();
Store store = context.getStore(ExtensionContext.Namespace.create(testMethod));

int currentRetries = store.getOrDefault("retries", Integer.class, 0);

boolean shouldRetry = false;
for (Class<? extends Throwable> retryOnException : retryOnExceptions) {
if (retryOnException.isInstance(throwable)) {
shouldRetry = true;
break;
}
}

if (shouldRetry && currentRetries < maxRetries) {
store.put("retries", currentRetries + 1);
context.getRequiredTestMethod().invoke(context.getRequiredTestInstance());
} else {
throw throwable;
}
} else {
throw throwable;
}
}

@Override
public void beforeEach(ExtensionContext context) throws Exception {
Method testMethod = context.getRequiredTestMethod();
if (testMethod.isAnnotationPresent(RetryOnExceptions.class)) {
Store store = context.getStore(ExtensionContext.Namespace.create(testMethod));
store.put("retries", 0); // Reset retry count before each test
}
}
}

0 comments on commit dfa4b07

Please sign in to comment.