Skip to content

Commit

Permalink
Chat Stream Response
Browse files Browse the repository at this point in the history
  • Loading branch information
amengus87 committed Jun 5, 2024
1 parent 021c5b7 commit c978062
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import java.util.List;
import java.util.UUID;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import ai.dragon.dto.openai.completion.OpenAiChatCompletionChoice;
import ai.dragon.dto.openai.completion.OpenAiChatCompletionRequest;
Expand All @@ -18,6 +18,7 @@
import ai.dragon.dto.openai.completion.OpenAiCompletionRequest;
import ai.dragon.dto.openai.completion.OpenAiCompletionResponse;
import ai.dragon.dto.openai.completion.OpenAiCompletionUsage;
import ai.dragon.service.SseService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
Expand All @@ -26,6 +27,9 @@
@RequestMapping("/api/ragapi/v1")
@Tag(name = "Open AI Compatible", description = "Compatible Endpoints following Open AI API Format")
public class OpenAiCompatibleV1ApiController {
@Autowired
private SseService sseService;

@PostMapping("/completions")
@Operation(summary = "Creates a completion", description = "Creates a completion for the provided prompt and parameters.")
public OpenAiCompletionResponse completions(@Valid @RequestBody OpenAiCompletionRequest request)
Expand Down Expand Up @@ -59,39 +63,59 @@ public OpenAiCompletionResponse completions(@Valid @RequestBody OpenAiCompletion
@PostMapping("/chat/completions")
@Operation(summary = "Creates a chat completion", description = "Creates a chat completion for the provided prompt and parameters.")
public Object chatCompletions(@Valid @RequestBody OpenAiChatCompletionRequest request) throws Exception {
OpenAiChatCompletionResponse response = new OpenAiChatCompletionResponse();

response.setId(UUID.randomUUID().toString());
response.setModel(request.getModel());
response.setCreated(System.currentTimeMillis() / 1000);
response.setObject("chat.completion");
response.setUsage(OpenAiCompletionUsage
.builder()
.completion_tokens(0)
.prompt_tokens(0)
.total_tokens(0)
.build());
if (Boolean.TRUE.equals(request.getStream())) {
UUID emitterID = sseService.createEmitter();

List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
choices.add(OpenAiChatCompletionChoice
.builder()
.index(0)
.finish_reason("stop")
.message(OpenAiCompletionMessage
for (int i = 0; i < 3; i++) {
OpenAiChatCompletionResponse responseChunk = new OpenAiChatCompletionResponse();
responseChunk.setId(emitterID.toString());
responseChunk.setModel(request.getModel());
responseChunk.setCreated(System.currentTimeMillis() / 1000);
responseChunk.setObject("chat.completion.chunk");
List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
choices.add(OpenAiChatCompletionChoice
.builder()
.role("assistant")
.content("Hello, how can I help you today?")
.build())
.build());

response.setChoices(choices);

if (request.getStream() != null && request.getStream()) {
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
emitter.send("okok");
return emitter;
.index(0)
.finish_reason(i == 2 ? "stop" : null)
.delta(OpenAiCompletionMessage
.builder()
.role("assistant")
.content("Chunk : " + i + "\r\n")
.build())
.build());
responseChunk.setChoices(choices);
sseService.sendEvent(emitterID, responseChunk);
}
sseService.sendEvent(emitterID, "[DONE]");
new Thread(() -> {
sseService.complete(emitterID);
}).start();
return sseService.retrieveEmitter(emitterID);
} else {
OpenAiChatCompletionResponse response = new OpenAiChatCompletionResponse();
response.setId(UUID.randomUUID().toString());
response.setModel(request.getModel());
response.setCreated(System.currentTimeMillis() / 1000);
response.setObject("chat.completion");
response.setUsage(OpenAiCompletionUsage
.builder()
.completion_tokens(0)
.prompt_tokens(0)
.total_tokens(0)
.build());
List<OpenAiChatCompletionChoice> choices = new ArrayList<>();
choices.add(OpenAiChatCompletionChoice
.builder()
.index(0)
.finish_reason("stop")
.message(OpenAiCompletionMessage
.builder()
.role("assistant")
.content("Hello, how can I help you today?")
.build())
.build());
response.setChoices(choices);
return response;
}

return response;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ public class OpenAiChatCompletionChoice {
private String finish_reason;
private Integer index;
private OpenAiCompletionMessage message;
private OpenAiCompletionMessage delta;
}
62 changes: 62 additions & 0 deletions backend/src/main/java/ai/dragon/service/SseService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package ai.dragon.service;

import java.io.IOException;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

/**
 * In-memory registry of {@link SseEmitter} instances keyed by a generated {@link UUID}.
 *
 * <p>Emitters are stored in a {@link ConcurrentHashMap} and are dropped from it when they
 * complete, time out, fail, or when sending an event to them fails.
 */
@Service
public class SseService {
    private static final Logger logger = LoggerFactory.getLogger(SseService.class);

    /** Timeout, in milliseconds, applied by {@link #createEmitter()}. */
    private static final long DEFAULT_TIMEOUT_MILLIS = 10L * 1000L;

    private final ConcurrentHashMap<UUID, SseEmitter> emitters = new ConcurrentHashMap<>();

    /**
     * Creates and registers an emitter using the default timeout.
     *
     * @return the id under which the new emitter is registered
     */
    public UUID createEmitter() {
        return createEmitter(DEFAULT_TIMEOUT_MILLIS);
    }

    /**
     * Creates and registers an emitter with the given timeout.
     *
     * @param timeout value handed as-is to the {@link SseEmitter#SseEmitter(Long)} constructor
     * @return the id under which the new emitter is registered
     */
    public UUID createEmitter(Long timeout) {
        UUID id = UUID.randomUUID();
        addEmitter(id, new SseEmitter(timeout));
        return id;
    }

    /**
     * Registers an existing emitter under {@code id} and wires cleanup callbacks so the
     * registry entry is dropped once the emitter completes, times out, or errors.
     *
     * @param id      key under which the emitter is stored
     * @param emitter emitter to register
     */
    public void addEmitter(UUID id, SseEmitter emitter) {
        emitters.put(id, emitter);
        // Conditional removal: only evict the mapping if it still points at THIS emitter,
        // so a late callback cannot drop a newer emitter registered under the same id.
        Runnable unregister = () -> emitters.remove(id, emitter);
        emitter.onCompletion(unregister);
        emitter.onTimeout(unregister);
        emitter.onError(error -> unregister.run());
    }

    /**
     * Looks up a registered emitter.
     *
     * @param id key returned by {@link #createEmitter()} or passed to {@link #addEmitter}
     * @return the emitter, or {@code null} when none is registered under {@code id}
     */
    public SseEmitter retrieveEmitter(UUID id) {
        return emitters.get(id);
    }

    /**
     * Completes the emitter registered under {@code id} and removes it from the registry.
     * Logs a warning and does nothing else when no emitter is registered.
     *
     * @param id key of the emitter to complete
     */
    public void complete(UUID id) {
        // remove() is atomic, so concurrent callers cannot both complete the same emitter.
        SseEmitter emitter = emitters.remove(id);
        if (emitter == null) {
            logger.warn("No emitter found for id: {}", id);
            return;
        }
        emitter.complete();
    }

    /**
     * Sends {@code event} through the emitter registered under {@code id}.
     * Logs a warning and does nothing else when no emitter is registered. When the send
     * fails, the failure is logged, the emitter is unregistered and completed with the error.
     *
     * @param id    key of the target emitter
     * @param event payload passed to {@link SseEmitter#send(Object)}
     */
    public void sendEvent(UUID id, Object event) {
        SseEmitter emitter = emitters.get(id);
        if (emitter == null) {
            logger.warn("No emitter found for id: {}", id);
            return;
        }
        try {
            emitter.send(event);
        } catch (IOException | IllegalStateException e) {
            // Do not swallow silently: record why the stream ended and surface the
            // failure to the emitter instead of reporting a normal completion.
            logger.warn("Failed to send event to emitter {}", id, e);
            emitters.remove(id, emitter);
            emitter.completeWithError(e);
        }
    }
}

0 comments on commit c978062

Please sign in to comment.