Skip to content

Commit

Permalink
Merge pull request #835 from salman1993/main
Browse files Browse the repository at this point in the history
Access to tool calls and tool outputs in post_run_hook
  • Loading branch information
zzstoatzz authored Feb 8, 2024
2 parents 8ee6b7c + ecda861 commit f786d4f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
10 changes: 9 additions & 1 deletion src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from openai.types.beta.threads.required_action_function_tool_call import (
RequiredActionFunctionToolCall,
)
from pydantic import BaseModel, Field, PrivateAttr

import marvin.utilities.tools
Expand Down Expand Up @@ -168,5 +171,10 @@ def chat(self, thread: Thread = None):
def pre_run_hook(self, run: "Run"):
pass

def post_run_hook(self, run: "Run"):
def post_run_hook(
self,
run: "Run",
tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None,
tool_outputs: Optional[list[dict[str, str]]] = None,
):
pass
24 changes: 20 additions & 4 deletions src/marvin/beta/assistants/runs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
from typing import Any, Callable, Optional, Union

from openai.types.beta.threads.required_action_function_tool_call import (
RequiredActionFunctionToolCall,
)
from openai.types.beta.threads.run import Run as OpenAIRun
from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
from pydantic import BaseModel, Field, PrivateAttr, field_validator
Expand Down Expand Up @@ -85,11 +88,14 @@ async def cancel_async(self):
run_id=self.run.id, thread_id=self.thread.id
)

async def _handle_step_requires_action(self):
async def _handle_step_requires_action(
self,
) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]:
client = get_openai_client()
if self.run.status != "requires_action":
return
return None, None
if self.run.required_action.type == "submit_tool_outputs":
tool_calls = []
tool_outputs = []
tools = self.get_tools()

Expand All @@ -110,10 +116,12 @@ async def _handle_step_requires_action(self):
tool_outputs.append(
dict(tool_call_id=tool_call.id, output=output or "")
)
tool_calls.append(tool_call)

await client.beta.threads.runs.submit_tool_outputs(
thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs
)
return tool_calls, tool_outputs

def get_instructions(self) -> str:
if self.instructions is None:
Expand Down Expand Up @@ -157,10 +165,16 @@ async def run_async(self) -> "Run":

self.assistant.pre_run_hook(run=self)

tool_calls = None
tool_outputs = None

try:
while self.run.status in ("queued", "in_progress", "requires_action"):
if self.run.status == "requires_action":
await self._handle_step_requires_action()
(
tool_calls,
tool_outputs,
) = await self._handle_step_requires_action()
await asyncio.sleep(0.1)
await self.refresh_async()
except CancelRun as exc:
Expand All @@ -174,7 +188,9 @@ async def run_async(self) -> "Run":
if self.run.status == "failed":
logger.debug(f"Run failed. Last error was: {self.run.last_error}")

self.assistant.post_run_hook(run=self)
self.assistant.post_run_hook(
run=self, tool_calls=tool_calls, tool_outputs=tool_outputs
)
return self


Expand Down

0 comments on commit f786d4f

Please sign in to comment.