Fix: Adding finish_reason on last chunk (#47)
* Adding CheckStyle module for banning OpenAI API Keys

* RaaG: Begin with langchain4j assistant

* Init RaagService to call OpenAI LLM

* Fix OpenAiCompatibleV1ApiControllerTest

* Adding finish_reason in order to fix: [OpenAIClient.chatCompletion][stream] Missing finish_reason

* Increase SSE timeout from 10 to 90 secs

* Increase OpenAI client timeout from 60 to 90 secs
amengus87 authored Jun 11, 2024
1 parent 6aca365 commit 6b2b044
Showing 5 changed files with 71 additions and 49 deletions.
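
The streaming change is the heart of this commit: once langchain4j's TokenStream completes, RaagService now emits one last chat.completion.chunk whose choice carries finish_reason "stop" (with an empty delta content) before the "[DONE]" sentinel. An OpenAI-compatible client typically keeps reading "data:" lines until it sees that finish_reason and then the sentinel; before this fix no chunk ever carried a finish_reason, hence the reported "Missing finish_reason" error. A minimal, hypothetical consumer-side sketch of that expectation — the class and method names below are illustrative, not part of this repository:

import java.io.BufferedReader;

public final class SseStreamReader {

    /** Reads "data:" lines from an SSE stream until it signals completion. */
    public static void consume(BufferedReader reader) throws Exception {
        String line;
        while ((line = reader.readLine()) != null) {
            if (!line.startsWith("data:")) {
                continue; // skip blank keep-alive lines and comments
            }
            String payload = line.substring("data:".length()).trim();
            if ("[DONE]".equals(payload)) {
                break; // sentinel sent after the final chunk
            }
            // Crude string check instead of real JSON parsing, for brevity:
            boolean lastChunk = payload.contains("\"finish_reason\":\"stop\"");
            System.out.println((lastChunk ? "final chunk: " : "delta: ") + payload);
        }
    }
}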
@@ -116,30 +116,29 @@ public Object chatCompletions(@Valid @RequestBody OpenAiChatCompletionRequest re
         if (Boolean.TRUE.equals(request.getStream())) {
             return raagService.chatResponse(farm, request);
         } else {
-            OpenAiChatCompletionResponse response = new OpenAiChatCompletionResponse();
-            response.setId(UUID.randomUUID().toString());
-            response.setModel(request.getModel());
-            response.setCreated(System.currentTimeMillis() / 1000);
-            response.setObject("chat.completion");
-            response.setUsage(OpenAiCompletionUsage
+            return OpenAiChatCompletionResponse
                     .builder()
-                    .completion_tokens(0)
-                    .prompt_tokens(0)
-                    .total_tokens(0)
-                    .build());
-            List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
-            choices.add(OpenAiChatCompletionChoice
-                    .builder()
-                    .index(0)
-                    .finish_reason("stop")
-                    .message(OpenAiCompletionMessage
+                    .id(UUID.randomUUID().toString())
+                    .model(request.getModel())
+                    .created(System.currentTimeMillis() / 1000)
+                    .object("chat.completion")
+                    .usage(OpenAiCompletionUsage
                             .builder()
-                            .role("assistant")
-                            .content("Hello, how can I help you today?")
+                            .completion_tokens(0)
+                            .prompt_tokens(0)
+                            .total_tokens(0)
                             .build())
-                    .build());
-            response.setChoices(choices);
-            return response;
+                    .choices(List.of(OpenAiChatCompletionChoice
+                            .builder()
+                            .index(0)
+                            .finish_reason("stop")
+                            .message(OpenAiCompletionMessage
+                                    .builder()
+                                    .role("assistant")
+                                    .content("Hello, how can I help you today?")
+                                    .build())
+                            .build()))
+                    .build();
         }
     }
 }
@@ -2,9 +2,11 @@

 import java.util.List;

+import lombok.Builder;
 import lombok.Data;

 @Data
+@Builder
 public class OpenAiChatCompletionResponse {
     private String id;
     private String object;
@@ -1,6 +1,9 @@
 package ai.dragon.enumeration;

+import java.time.Duration;
+
 import ai.dragon.dto.llm.StreamingChatLanguageModelDefinition;
+import ai.dragon.service.SseService;
 import dev.langchain4j.model.openai.OpenAiStreamingChatModel;

 public enum LanguageModelType {
@@ -36,6 +39,7 @@ public StreamingChatLanguageModelDefinition getStreamingChatLanguageModel() thro
                     .builder()
                     .apiKey(parameters.getApiKey())
                     .modelName(parameters.getModelName())
+                    .timeout(Duration.ofSeconds(SseService.DEFAULT_TIMEOUT))
                     .build();
             })
             .providerType(ProviderType.OpenAI)
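
A side note on units, offered as an observation rather than as part of the change: SseService.DEFAULT_TIMEOUT (introduced further down) is declared as 90L * 1000, i.e. milliseconds, while Duration.ofSeconds(...) interprets its argument as seconds, so the client timeout above resolves to 90 000 seconds. If a 90-second client timeout is the intent stated in the commit message, an equivalent hedged sketch would be:

import java.time.Duration;

class ClientTimeoutSketch {
    // Same value as SseService.DEFAULT_TIMEOUT in this commit: 90 000 ms.
    static final long DEFAULT_TIMEOUT = 90L * 1000;

    // Treats the constant as milliseconds, yielding the 90-second timeout
    // the commit message describes.
    static Duration clientTimeout() {
        return Duration.ofMillis(DEFAULT_TIMEOUT);
    }
}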
52 changes: 31 additions & 21 deletions backend/src/main/java/ai/dragon/service/RaagService.java
@@ -77,28 +77,14 @@ public SseEmitter chatResponse(FarmEntity farm, OpenAiChatCompletionRequest requ
         // TODO request.getMessages().get(0).getContent())
         TokenStream stream = assistant.chat((String) request.getMessages().get(0).getContent());
         UUID emitterID = sseService.createEmitter();
-        stream.onNext(nextChunk -> {
-            OpenAiChatCompletionResponse responseChunk = new OpenAiChatCompletionResponse();
-            responseChunk.setId(emitterID.toString());
-            responseChunk.setModel(request.getModel());
-            responseChunk.setCreated(System.currentTimeMillis() / 1000);
-            responseChunk.setObject("chat.completion.chunk");
-            List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
-            choices.add(OpenAiChatCompletionChoice
-                    .builder()
-                    .index(0)
-                    // .finish_reason(i == 2 ? "stop" : null)
-                    .delta(OpenAiCompletionMessage
-                            .builder()
-                            .role("assistant")
-                            .content(nextChunk)
-                            .build())
-                    .build());
-            responseChunk.setChoices(choices);
-            sseService.sendEvent(emitterID, responseChunk);
-        })
+        stream
+                .onNext(nextChunk -> {
+                    sseService.sendEvent(emitterID,
+                            this.createChatCompletionResponse(emitterID, request, nextChunk, false));
+                })
                 .onComplete(response -> {
-                    // response.finishReason(). // "stop"
+                    sseService.sendEvent(emitterID,
+                            this.createChatCompletionResponse(emitterID, request, "", true));
                     sseService.sendEvent(emitterID, "[DONE]");
                     sseService.complete(emitterID);
                 })
@@ -107,6 +93,30 @@ public SseEmitter chatResponse(FarmEntity farm, OpenAiChatCompletionRequest requ
         return sseService.retrieveEmitter(emitterID);
     }

+    private OpenAiChatCompletionResponse createChatCompletionResponse(
+            UUID emitterID,
+            OpenAiChatCompletionRequest request,
+            String nextChunk,
+            boolean isLastChunk) {
+        return OpenAiChatCompletionResponse
+                .builder()
+                .id(emitterID.toString())
+                .model(request.getModel())
+                .created(System.currentTimeMillis() / 1000)
+                .object("chat.completion.chunk")
+                .choices(List.of(OpenAiChatCompletionChoice
+                        .builder()
+                        .index(0)
+                        .finish_reason(isLastChunk ? "stop" : null)
+                        .delta(OpenAiCompletionMessage
+                                .builder()
+                                .role("assistant")
+                                .content(nextChunk)
+                                .build())
+                        .build()))
+                .build();
+    }
+
     private StreamingChatLanguageModel buildStreamingChatLanguageModel(FarmEntity farm) throws Exception {
         return farm
                 .getLanguageModel()
21 changes: 14 additions & 7 deletions backend/src/main/java/ai/dragon/service/SseService.java
@@ -11,12 +11,15 @@

 @Service
 public class SseService {
+    // TODO Custom timeout :
+    public static final long DEFAULT_TIMEOUT = 90L * 1000;
+
     private final Logger logger = LoggerFactory.getLogger(this.getClass());

     private ConcurrentHashMap<UUID, SseEmitter> emitters = new ConcurrentHashMap<>();

     public UUID createEmitter() {
-        return createEmitter(10L * 1000L);
+        return createEmitter(DEFAULT_TIMEOUT);
     }

     public UUID createEmitter(Long timeout) {
@@ -36,27 +39,31 @@ public SseEmitter retrieveEmitter(UUID id) {
         return emitters.get(id);
     }

-    public void complete(UUID id) {
+    public boolean complete(UUID id) {
         SseEmitter emitter = emitters.get(id);
         if (emitter == null) {
-            logger.warn("No emitter found for id: {}", id);
-            return;
+            logger.warn("Can't complete : No emitter found for id '{}'", id);
+            return false;
         }
         emitter.complete();
         emitters.remove(id);
+        return true;
     }

-    public void sendEvent(UUID id, Object event) {
+    public boolean sendEvent(UUID id, Object event) {
         SseEmitter emitter = emitters.get(id);
         if (emitter == null) {
-            logger.warn("No emitter found for id: {}", id);
-            return;
+            logger.info(event.toString());
+            logger.warn("Can't send event : No emitter found for id '{}'", id);
+            return false;
         }
         try {
             emitter.send(event);
         } catch (IOException e) {
             emitter.complete();
             emitters.remove(id);
+            return false;
         }
+        return true;
     }
 }
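
Since complete(...) and sendEvent(...) now report success instead of returning void, a producer can stop work as soon as the emitter is gone (for example after a client disconnect) rather than keep generating chunks that are only logged and dropped. A hypothetical caller-side sketch, not part of this commit, assuming it lives in the same package as SseService:

import java.util.UUID;

class StreamingProducerSketch {
    private final SseService sseService;

    StreamingProducerSketch(SseService sseService) {
        this.sseService = sseService;
    }

    /** Pushes chunks until the stream ends or the emitter goes away. */
    void push(UUID emitterID, Iterable<String> chunks) {
        for (String chunk : chunks) {
            if (!sseService.sendEvent(emitterID, chunk)) {
                return; // emitter missing or send failed: stop producing
            }
        }
        sseService.complete(emitterID);
    }
}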
