Skip to content

Commit

Permalink
chore: Add RetryOnExceptionsExtension for test execution exception ha…
Browse files Browse the repository at this point in the history
…ndling
  • Loading branch information
amengus87 committed Jul 26, 2024
1 parent 416dc88 commit 556d996
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.io.InterruptedIOException;
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.List;
import java.util.Map;
Expand All @@ -32,6 +34,7 @@
import ai.dragon.repository.SiloRepository;
import ai.dragon.service.IngestorService;
import ai.dragon.test.AbstractTest;
import ai.dragon.test.junit.extension.retry.RetryOnExceptions;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
Expand Down Expand Up @@ -122,6 +125,7 @@ static void cleanUp(FarmRepository farmRepository, SiloRepository siloRepository
}

@Test
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void listModels() throws Exception {
MistralAiClient client = MistralAiClient.builder()
.apiKey("TODO_PUT_KEY_HERE")
Expand All @@ -137,6 +141,7 @@ void listModels() throws Exception {

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testModelDoesntExistOpenAI() {
OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
Expand All @@ -153,6 +158,7 @@ void testModelDoesntExistOpenAI() {

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmNoSiloOpenAI() {
OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
Expand All @@ -172,6 +178,7 @@ void testFarmNoSiloOpenAI() {

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmCompletionOpenAI() {
OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
Expand All @@ -192,6 +199,7 @@ void testFarmCompletionOpenAI() {

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmCompletionStreamOpenAI() {
OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
Expand All @@ -215,6 +223,7 @@ void testFarmCompletionStreamOpenAI() {

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
void testFarmChatRewriteQueryOpenAI() {
OpenAiClient client = OpenAiClient.builder()
.openAiApiKey("TODO_PUT_KEY_HERE")
Expand Down Expand Up @@ -251,6 +260,7 @@ void testFarmChatRewriteQueryOpenAI() {

@Test
@EnabledIf("canRunOpenAiRelatedTests")
@RetryOnExceptions(value = 2, onExceptions = { InterruptedIOException.class, SocketTimeoutException.class })
@SuppressWarnings("unchecked")
void testFarmCompletionWithMetadataFilterOpenAI() {
Map.of(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package ai.dragon.test.junit.extension.retry;

import org.junit.jupiter.api.extension.ExtendWith;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@ExtendWith(RetryOnExceptionsExtension.class)
public @interface RetryOnExceptions {
int value() default 3; // Default retry count
Class<? extends Throwable>[] onExceptions() default { Throwable.class }; // Exceptions to retry on
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package ai.dragon.test.junit.extension.retry;

import java.lang.reflect.Method;

import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;

public class RetryOnExceptionsExtension implements TestExecutionExceptionHandler, BeforeEachCallback {

@Override
public void handleTestExecutionException(ExtensionContext context, Throwable throwable) throws Throwable {
Method testMethod = context.getRequiredTestMethod();
RetryOnExceptions retry = testMethod.getAnnotation(RetryOnExceptions.class);

if (retry != null) {
int maxRetries = retry.value();
Class<? extends Throwable>[] retryOnExceptions = retry.onExceptions();
Store store = context.getStore(ExtensionContext.Namespace.create(testMethod));

int currentRetries = store.getOrDefault("retries", Integer.class, 0);

boolean shouldRetry = false;
for (Class<? extends Throwable> retryOnException : retryOnExceptions) {
if (retryOnException.isInstance(throwable)) {
shouldRetry = true;
break;
}
}

if (shouldRetry && currentRetries < maxRetries) {
store.put("retries", currentRetries + 1);
context.getRequiredTestMethod().invoke(context.getRequiredTestInstance());
} else {
throw throwable;
}
} else {
throw throwable;
}
}

@Override
public void beforeEach(ExtensionContext context) throws Exception {
Method testMethod = context.getRequiredTestMethod();
if (testMethod.isAnnotationPresent(RetryOnExceptions.class)) {
Store store = context.getStore(ExtensionContext.Namespace.create(testMethod));
store.put("retries", 0); // Reset retry count before each test
}
}
}

0 comments on commit 556d996

Please sign in to comment.