Skip to content

Commit

Permalink
Merge pull request #188 from PrefectHQ/improve-streaming-response
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin authored Apr 8, 2023
2 parents 8a05941 + 4978687 commit e018bdf
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 32 deletions.
6 changes: 5 additions & 1 deletion src/marvin/bot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,11 @@ async def say(
else:
raise ValueError(f"Unknown on_error value: {on_error}")
else:
raise RuntimeError("Failed to validate response after 3 attempts")
response = (
"Error: could not validate response after"
f" {MAX_VALIDATION_ATTEMPTS} attempts."
)
parsed_response = response

if validated:
parsed_response = self.response_format.parse_response(response)
Expand Down
66 changes: 35 additions & 31 deletions src/marvin/cli/tui.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import json
import logging
import re
import warnings
from functools import partial
from typing import Optional

import dotenv
Expand Down Expand Up @@ -36,6 +37,8 @@
handlers=[TextualHandler()],
)

USING_PLUGIN_REGEX = re.compile(r'{\s*"action":\s*"run-plugin",\s*"name":\s*"(.*?)"')


@marvin.ai_fn(llm_model_name="gpt-3.5-turbo", llm_model_temperature=1)
async def name_thread(history: str, personality: str, current_name: str = None) -> str:
Expand Down Expand Up @@ -166,14 +169,19 @@ class ResponseHover(Message):


class ResponseBody(Markdown):
pass
text: str = ""

def update(self, markdown: str):
self.text = markdown
super().update(markdown)

def on_enter(self):
self.post_message(ResponseHover())


class Response(Container):
body = None
stream_finished: bool = False

def __init__(self, message: marvin.models.threads.Message, **kwargs) -> None:
classes = kwargs.setdefault("classes", "")
Expand Down Expand Up @@ -292,7 +300,6 @@ def clear_responses(self) -> None:
for response in responses:
response.remove()
self.bot_name = getattr(self.app.bot, "name")
print(self.bot_name)
empty = self.query_one("Conversation #empty-thread-container")
empty.remove_class("hidden")

Expand Down Expand Up @@ -608,34 +615,21 @@ async def on_button_pressed(self, event: Button.Pressed) -> None:
elif event.button.id == "quit":
self.app.exit()

async def update_last_bot_response(self, token_buffer: list[str]):
async def stream_bot_response(self, token_buffer: list[str], response: BotResponse):
streaming_response = "".join(token_buffer)
responses = self.query("Response")
if responses:
response = responses.last()
if not isinstance(response, BotResponse):
conversation = self.query_one("Conversation", Conversation)
await conversation.add_response(
BotResponse(
marvin.models.threads.Message(
role="bot",
name=self.app.bot.name,
bot_id=self.app.bot.id,
content=streaming_response,
)
)
)
else:
# the bot is going to use a plugin
if match := marvin.bot.base.PLUGIN_REGEX.search(streaming_response):
try:
plugin_name = json.loads(match.group(1))["name"]
response.body.update(f'Using plugin "{plugin_name}"...')
except Exception:
response.body.update("Using plugin...")
else:
response.message.content = streaming_response
response.body.update(streaming_response)

if not self.app.is_mounted(response):
conversation = self.query_one("Conversation", Conversation)
await conversation.add_response(response)

# the bot is going to use a plugin
if match := USING_PLUGIN_REGEX.search(streaming_response):
plugin_name = match.group(1)
if not response.body.text == f'Using plugin "{plugin_name}"...':
response.body.update(f'Using plugin "{plugin_name}"...')
else:
response.message.content = streaming_response
response.body.update(streaming_response)

# scroll to bottom
messages = self.query_one("Conversation #messages", VerticalScroll)
Expand All @@ -646,9 +640,19 @@ async def get_bot_response(self, event: Input.Submitted) -> str:
bot = self.app.bot
self.app.bot_responding = True
try:
bot_response = BotResponse(
marvin.models.threads.Message(
role="bot",
name=self.app.bot.name,
bot_id=self.app.bot.id,
content="",
)
)
response = await bot.say(
event.value,
on_token_callback=self.update_last_bot_response,
on_token_callback=partial(
self.stream_bot_response, response=bot_response
),
)

self.query_one("Conversation", Conversation)
Expand Down

0 comments on commit e018bdf

Please sign in to comment.