From 015019b0cf1934bd31ca03ea607a643764c05df9 Mon Sep 17 00:00:00 2001 From: Arnaud Mengus Date: Wed, 5 Jun 2024 15:37:43 +0000 Subject: [PATCH] Completions Stream Chunks --- .../OpenAiCompatibleV1ApiController.java | 72 ++++++++++++------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/backend/src/main/java/ai/dragon/controller/api/ragapi/OpenAiCompatibleV1ApiController.java b/backend/src/main/java/ai/dragon/controller/api/ragapi/OpenAiCompatibleV1ApiController.java index 4c6d9e9f..bf3cf955 100644 --- a/backend/src/main/java/ai/dragon/controller/api/ragapi/OpenAiCompatibleV1ApiController.java +++ b/backend/src/main/java/ai/dragon/controller/api/ragapi/OpenAiCompatibleV1ApiController.java @@ -32,32 +32,53 @@ public class OpenAiCompatibleV1ApiController { @PostMapping("/completions") @Operation(summary = "Creates a completion", description = "Creates a completion for the provided prompt and parameters.") - public OpenAiCompletionResponse completions(@Valid @RequestBody OpenAiCompletionRequest request) + public Object completions(@Valid @RequestBody OpenAiCompletionRequest request) throws Exception { - OpenAiCompletionResponse response = new OpenAiCompletionResponse(); - - response.setId(UUID.randomUUID().toString()); - response.setModel(request.getModel()); - response.setCreated(System.currentTimeMillis() / 1000); - response.setObject("text_completion"); - response.setUsage(OpenAiCompletionUsage - .builder() - .completion_tokens(0) - .prompt_tokens(0) - .total_tokens(0) - .build()); - - List choices = new ArrayList<>(); - choices.add(OpenAiCompletionChoice - .builder() - .index(0) - .finish_reason("stop") - .text("Hello, how can I help you today?") - .build()); - - response.setChoices(choices); - - return response; + if (Boolean.TRUE.equals(request.getStream())) { + UUID emitterID = sseService.createEmitter(); + for (int i = 0; i < 3; i++) { + OpenAiCompletionResponse responseChunk = new OpenAiCompletionResponse(); + responseChunk.setId(emitterID.toString()); + responseChunk.setModel(request.getModel()); + responseChunk.setCreated(System.currentTimeMillis() / 1000); + responseChunk.setObject("text_completion"); + List choices = new ArrayList<>(); + choices.add(OpenAiCompletionChoice + .builder() + .index(0) + .finish_reason(i == 2 ? "stop" : null) + .text("Chunk : " + i + "\r\n") + .build()); + responseChunk.setChoices(choices); + sseService.sendEvent(emitterID, responseChunk); + } + sseService.sendEvent(emitterID, "[DONE]"); + new Thread(() -> { + sseService.complete(emitterID); + }).start(); + return sseService.retrieveEmitter(emitterID); + } else { + OpenAiCompletionResponse response = new OpenAiCompletionResponse(); + response.setId(UUID.randomUUID().toString()); + response.setModel(request.getModel()); + response.setCreated(System.currentTimeMillis() / 1000); + response.setObject("text_completion"); + response.setUsage(OpenAiCompletionUsage + .builder() + .completion_tokens(0) + .prompt_tokens(0) + .total_tokens(0) + .build()); + List choices = new ArrayList<>(); + choices.add(OpenAiCompletionChoice + .builder() + .index(0) + .finish_reason("stop") + .text("Hello, how can I help you today?") + .build()); + response.setChoices(choices); + return response; + } } @PostMapping("/chat/completions") @@ -65,7 +86,6 @@ public OpenAiCompletionResponse completions(@Valid @RequestBody OpenAiCompletion public Object chatCompletions(@Valid @RequestBody OpenAiChatCompletionRequest request) throws Exception { if (Boolean.TRUE.equals(request.getStream())) { UUID emitterID = sseService.createEmitter(); - for (int i = 0; i < 3; i++) { OpenAiChatCompletionResponse responseChunk = new OpenAiChatCompletionResponse(); responseChunk.setId(emitterID.toString());