Skip to content

Commit

Permalink
Merge pull request #941 from andehr/tool-choice
Browse files Browse the repository at this point in the history
Support `tool_choice` parameter for runs
  • Loading branch information
zzstoatzz authored Jun 26, 2024
2 parents 882039f + 7d819db commit 1735b7d
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/marvin/beta/assistants/runs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Literal, Optional, Union

from openai import AsyncAssistantEventHandler
from openai.types.beta.threads import Message
Expand Down Expand Up @@ -34,6 +34,8 @@ class Run(BaseModel, ExposeSyncMethodsMixin):
for the run.
additional_tools (list[AssistantTool], optional): Additional tools to append
to the assistant's tools.
tool_choice (Union[Literal["auto", "none", "required"], AssistantTool], optional):
The tool use behaviour for the run.
run (OpenAIRun): The OpenAI run object.
data (Any): Any additional data associated with the run.
"""
Expand Down Expand Up @@ -67,6 +69,12 @@ class Run(BaseModel, ExposeSyncMethodsMixin):
None,
description="Additional tools to append to the assistant's tools. ",
)
tool_choice: Optional[
Union[Literal["none", "auto", "required"], AssistantTool]
] = Field(
default=None,
description="The tool use behaviour for the run. Can be 'none', 'auto', 'required', or a specific tool.",
)
run: OpenAIRun = Field(None, repr=False)
data: Any = None

Expand Down Expand Up @@ -154,6 +162,13 @@ def _get_run_kwargs(self, thread: Thread = None, **run_kwargs) -> dict:
if model := self._get_model():
run_kwargs["model"] = model

if tool_choice := self.tool_choice:
run_kwargs["tool_choice"] = (
tool_choice.model_dump(mode="json")
if isinstance(tool_choice, Tool)
else tool_choice
)

return run_kwargs

async def get_tool_outputs(self, run: OpenAIRun) -> list[str]:
Expand Down

0 comments on commit 1735b7d

Please sign in to comment.