From 09dfb8aea6b2bc60309f9ba8e1f3b0c2acf1e897 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Fri, 15 Mar 2024 14:39:11 -0400 Subject: [PATCH 1/4] =?UTF-8?q?Tool=20=E2=86=92=20FunctionTool?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/marvin/_mappings/base_model.py | 6 +++--- src/marvin/_mappings/types.py | 4 ++-- src/marvin/beta/applications/applications.py | 4 ++-- src/marvin/beta/applications/state/state.py | 4 ++-- src/marvin/beta/assistants/assistants.py | 4 ++-- src/marvin/beta/assistants/runs.py | 6 +++--- src/marvin/tools/assistants.py | 4 ++-- src/marvin/types.py | 13 ++++++++----- src/marvin/utilities/tools.py | 10 +++++----- 9 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/marvin/_mappings/base_model.py b/src/marvin/_mappings/base_model.py index a45dbc3e6..1d5442931 100644 --- a/src/marvin/_mappings/base_model.py +++ b/src/marvin/_mappings/base_model.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode -from marvin.types import Function, Tool, ToolSet +from marvin.types import Function, FunctionTool, ToolSet class FunctionSchema(GenerateJsonSchema): @@ -15,10 +15,10 @@ def generate(self, schema: Any, mode: JsonSchemaMode = "validation"): def cast_model_to_tool( model: type[BaseModel], -) -> Tool[BaseModel]: +) -> FunctionTool[BaseModel]: model_name = model.__name__ model_description = model.__doc__ - return Tool[BaseModel]( + return FunctionTool[BaseModel]( type="function", function=Function[BaseModel]( name=model_name, diff --git a/src/marvin/_mappings/types.py b/src/marvin/_mappings/types.py index 0f1d0ce61..f4ac636ed 100644 --- a/src/marvin/_mappings/types.py +++ b/src/marvin/_mappings/types.py @@ -6,7 +6,7 @@ from pydantic.fields import FieldInfo from marvin.settings import settings -from marvin.types import Grammar, Tool, ToolSet +from marvin.types import FunctionTool, Grammar, ToolSet from .base_model import cast_model_to_tool, cast_model_to_toolset @@ -46,7 +46,7 @@ def cast_type_to_tool( field_name: str, field_description: str, python_function: Optional[Callable[..., Any]] = None, -) -> Tool[BaseModel]: +) -> FunctionTool[BaseModel]: return cast_model_to_tool( model=cast_type_to_model( _type, diff --git a/src/marvin/beta/applications/applications.py b/src/marvin/beta/applications/applications.py index 0c4545b64..20559cd92 100644 --- a/src/marvin/beta/applications/applications.py +++ b/src/marvin/beta/applications/applications.py @@ -6,7 +6,7 @@ from marvin.beta.assistants import Assistant from marvin.beta.assistants.runs import Run from marvin.tools.assistants import AssistantTool -from marvin.types import Tool +from marvin.types import FunctionTool from marvin.utilities.jinja import Environment as JinjaEnvironment from marvin.utilities.tools import tool_from_function @@ -66,7 +66,7 @@ def get_tools(self) -> list[AssistantTool]: tools = [] for tool in [self.state.as_tool(name="state")] + self.tools: - if not isinstance(tool, Tool): + if not isinstance(tool, FunctionTool): kwargs = None signature = inspect.signature(tool) for parameter in signature.parameters.values(): diff --git a/src/marvin/beta/applications/state/state.py b/src/marvin/beta/applications/state/state.py index 75bc8cdf5..a6bdc708c 100644 --- a/src/marvin/beta/applications/state/state.py +++ b/src/marvin/beta/applications/state/state.py @@ -5,7 +5,7 @@ from jsonpatch import JsonPatch from pydantic import BaseModel, Field, PrivateAttr, SerializeAsAny -from marvin.types import Tool +from marvin.types import FunctionTool from marvin.utilities.tools import tool_from_function @@ -66,7 +66,7 @@ def update_state_jsonpatches(self, patches: list[JSONPatchModel]): self.set_state(state) return "Application state updated successfully!" - def as_tool(self, name: str = None) -> "Tool": + def as_tool(self, name: str = None) -> "FunctionTool": if name is None: name = "state" schema = self.get_schema() diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index f0c562948..5ad27cf25 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -8,7 +8,7 @@ import marvin.utilities.openai import marvin.utilities.tools from marvin.tools.assistants import AssistantTool -from marvin.types import Tool +from marvin.types import FunctionTool from marvin.utilities.asyncio import ( ExposeSyncMethodsMixin, expose_sync_method, @@ -64,7 +64,7 @@ def get_tools(self) -> list[AssistantTool]: return [ ( tool - if isinstance(tool, Tool) + if isinstance(tool, FunctionTool) else marvin.utilities.tools.tool_from_function(tool) ) for tool in self.tools diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 3e77c2266..58ca4f52a 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -11,7 +11,7 @@ import marvin.utilities.openai import marvin.utilities.tools from marvin.tools.assistants import AssistantTool, CancelRun -from marvin.types import Tool +from marvin.types import FunctionTool from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.logging import get_logger @@ -61,12 +61,12 @@ class Run(BaseModel, ExposeSyncMethodsMixin): data: Any = None @field_validator("tools", "additional_tools", mode="before") - def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]): + def format_tools(cls, tools: Union[None, list[Union[FunctionTool, Callable]]]): if tools is not None: return [ ( tool - if isinstance(tool, Tool) + if isinstance(tool, FunctionTool) else marvin.utilities.tools.tool_from_function(tool) ) for tool in tools diff --git a/src/marvin/tools/assistants.py b/src/marvin/tools/assistants.py index abe32c3cc..721665190 100644 --- a/src/marvin/tools/assistants.py +++ b/src/marvin/tools/assistants.py @@ -1,11 +1,11 @@ from typing import Any, Union -from marvin.types import CodeInterpreterTool, RetrievalTool, Tool +from marvin.types import CodeInterpreterTool, FunctionTool, RetrievalTool Retrieval = RetrievalTool() CodeInterpreter = CodeInterpreterTool() -AssistantTool = Union[RetrievalTool, CodeInterpreterTool, Tool] +AssistantTool = Union[RetrievalTool, CodeInterpreterTool, FunctionTool] class CancelRun(Exception): diff --git a/src/marvin/types.py b/src/marvin/types.py index 800d9e6db..bb106d790 100644 --- a/src/marvin/types.py +++ b/src/marvin/types.py @@ -60,21 +60,24 @@ def create( return instance -class Tool(MarvinType, Generic[T]): +class Tool(MarvinType): type: str + + +class FunctionTool(Tool, Generic[T]): function: Optional[Function[T]] = None class ToolSet(MarvinType, Generic[T]): - tools: Optional[list[Tool[T]]] = None + tools: Optional[list[Union[Tool, FunctionTool[T]]]] = None tool_choice: Optional[Union[Literal["auto"], dict[str, Any]]] = None -class RetrievalTool(Tool[T]): +class RetrievalTool(FunctionTool[T]): type: Literal["retrieval"] = "retrieval" -class CodeInterpreterTool(Tool[T]): +class CodeInterpreterTool(FunctionTool[T]): type: Literal["code_interpreter"] = "code_interpreter" @@ -244,7 +247,7 @@ class Run(MarvinType, Generic[T]): status: str model: str instructions: Optional[str] - tools: Optional[list[Tool[T]]] = None + tools: Optional[list[FunctionTool[T]]] = None metadata: dict[str, str] diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py index f55f1caa0..27e9ae2ba 100644 --- a/src/marvin/utilities/tools.py +++ b/src/marvin/utilities/tools.py @@ -17,7 +17,7 @@ from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode -from marvin.types import Function, Tool +from marvin.types import Function, FunctionTool from marvin.utilities.asyncio import run_sync from marvin.utilities.logging import get_logger @@ -63,7 +63,7 @@ def generate(self, schema: Any, mode: JsonSchemaMode = "validation"): return json_schema -def tool_from_type(type_: U, tool_name: str = None) -> Tool[U]: +def tool_from_type(type_: U, tool_name: str = None) -> FunctionTool[U]: """ Creates an OpenAI-compatible tool from a Python type. """ @@ -99,7 +99,7 @@ def tool_from_model(model: type[M], python_fn: Callable[[str], M] = None): def tool_fn(**data) -> M: return TypeAdapter(model).validate_python(data) - return Tool[M]( + return FunctionTool[M]( type="function", function=Function[M].create( name=model.__name__, @@ -130,7 +130,7 @@ def tool_from_function( fn, config=pydantic.ConfigDict(arbitrary_types_allowed=True) ).json_schema() - return Tool[T]( + return FunctionTool[T]( type="function", function=Function[T].create( name=name or fn.__name__, @@ -142,7 +142,7 @@ def tool_from_function( def call_function_tool( - tools: list[Tool], + tools: list[FunctionTool], function_name: str, function_arguments_json: str, return_string: bool = False, From 1bfc7ac2436159f1f98364673ff1fd00eb8f1dbd Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Fri, 15 Mar 2024 14:56:38 -0400 Subject: [PATCH 2/4] Ensure builtin tools can be used --- cookbook/slackbot/start.py | 6 ++-- src/marvin/beta/applications/applications.py | 4 +-- src/marvin/beta/assistants/assistants.py | 4 +-- src/marvin/beta/assistants/runs.py | 36 +++++++------------- src/marvin/beta/assistants/threads.py | 27 +++++---------- src/marvin/types.py | 4 +-- 6 files changed, 30 insertions(+), 51 deletions(-) diff --git a/cookbook/slackbot/start.py b/cookbook/slackbot/start.py index 2b3d0f5e6..444e65a29 100644 --- a/cookbook/slackbot/start.py +++ b/cookbook/slackbot/start.py @@ -142,13 +142,13 @@ async def handle_message(payload: SlackPayload) -> Completed: ai_response_text, "green", ) + messages = await assistant_thread.get_messages_async() + event = emit_assistant_completed_event( child_assistant=ai, parent_app=get_parent_app() if ENABLE_PARENT_APP else None, payload={ - "messages": await assistant_thread.get_messages_async( - json_compatible=True - ), + "messages": [m.model_dump() for m in messages], "metadata": assistant_thread.metadata, "user": { "id": event.user, diff --git a/src/marvin/beta/applications/applications.py b/src/marvin/beta/applications/applications.py index 20559cd92..0c4545b64 100644 --- a/src/marvin/beta/applications/applications.py +++ b/src/marvin/beta/applications/applications.py @@ -6,7 +6,7 @@ from marvin.beta.assistants import Assistant from marvin.beta.assistants.runs import Run from marvin.tools.assistants import AssistantTool -from marvin.types import FunctionTool +from marvin.types import Tool from marvin.utilities.jinja import Environment as JinjaEnvironment from marvin.utilities.tools import tool_from_function @@ -66,7 +66,7 @@ def get_tools(self) -> list[AssistantTool]: tools = [] for tool in [self.state.as_tool(name="state")] + self.tools: - if not isinstance(tool, FunctionTool): + if not isinstance(tool, Tool): kwargs = None signature = inspect.signature(tool) for parameter in signature.parameters.values(): diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 5ad27cf25..f0c562948 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -8,7 +8,7 @@ import marvin.utilities.openai import marvin.utilities.tools from marvin.tools.assistants import AssistantTool -from marvin.types import FunctionTool +from marvin.types import Tool from marvin.utilities.asyncio import ( ExposeSyncMethodsMixin, expose_sync_method, @@ -64,7 +64,7 @@ def get_tools(self) -> list[AssistantTool]: return [ ( tool - if isinstance(tool, FunctionTool) + if isinstance(tool, Tool) else marvin.utilities.tools.tool_from_function(tool) ) for tool in self.tools diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 58ca4f52a..cbf946ca1 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -6,12 +6,12 @@ ) from openai.types.beta.threads.run import Run as OpenAIRun from openai.types.beta.threads.runs import RunStep as OpenAIRunStep -from pydantic import BaseModel, Field, PrivateAttr, field_validator +from pydantic import BaseModel, Field, field_validator import marvin.utilities.openai import marvin.utilities.tools from marvin.tools.assistants import AssistantTool, CancelRun -from marvin.types import FunctionTool +from marvin.types import Tool from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.logging import get_logger @@ -39,6 +39,7 @@ class Run(BaseModel, ExposeSyncMethodsMixin): data (Any): Any additional data associated with the run. """ + id: Optional[str] = None thread: Thread assistant: Assistant instructions: Optional[str] = Field( @@ -61,12 +62,12 @@ class Run(BaseModel, ExposeSyncMethodsMixin): data: Any = None @field_validator("tools", "additional_tools", mode="before") - def format_tools(cls, tools: Union[None, list[Union[FunctionTool, Callable]]]): + def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]): if tools is not None: return [ ( tool - if isinstance(tool, FunctionTool) + if isinstance(tool, Tool) else marvin.utilities.tools.tool_from_function(tool) ) for tool in tools @@ -77,7 +78,7 @@ async def refresh_async(self): """Refreshes the run.""" client = marvin.utilities.openai.get_openai_client() self.run = await client.beta.threads.runs.retrieve( - run_id=self.run.id, thread_id=self.thread.id + run_id=self.run.id if self.run else self.id, thread_id=self.thread.id ) @expose_sync_method("cancel") @@ -85,7 +86,7 @@ async def cancel_async(self): """Cancels the run.""" client = marvin.utilities.openai.get_openai_client() await client.beta.threads.runs.cancel( - run_id=self.run.id, thread_id=self.thread.id + run_id=self.run.id if self.run else self.id, thread_id=self.thread.id ) async def _handle_step_requires_action( @@ -156,6 +157,10 @@ async def run_async(self) -> "Run": if self.tools is not None or self.additional_tools is not None: create_kwargs["tools"] = self.get_tools() + if self.id is not None: + raise ValueError( + "This run object was provided an ID; can not create a new run." + ) async with self.assistant: self.run = await client.beta.threads.runs.create( thread_id=self.thread.id, @@ -195,25 +200,10 @@ async def run_async(self) -> "Run": class RunMonitor(BaseModel): - run_id: str - thread_id: str - _run: Run = PrivateAttr() - _thread: Thread = PrivateAttr() + run: Run + thread: Thread steps: list[OpenAIRunStep] = [] - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._thread = Thread(**kwargs["thread_id"]) - self._run = Run(**kwargs["run_id"], thread=self.thread) - - @property - def thread(self): - return self._thread - - @property - def run(self): - return self._run - async def refresh_run_steps_async(self): """ Asynchronously refreshes and updates the run steps list. diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py index 14d071cd2..bb508fcca 100644 --- a/src/marvin/beta/assistants/threads.py +++ b/src/marvin/beta/assistants/threads.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional # for openai < 1.14.0 try: @@ -18,7 +18,6 @@ run_sync, ) from marvin.utilities.logging import get_logger -from marvin.utilities.pydantic import parse_as logger = get_logger("Threads") @@ -100,25 +99,18 @@ async def get_messages_async( limit: int = None, before_message: Optional[str] = None, after_message: Optional[str] = None, - json_compatible: bool = False, - ) -> list[Union[Message, dict]]: + ) -> list[Message]: """ Asynchronously retrieves messages from the thread. Args: limit (int, optional): The maximum number of messages to return. - before_message (str, optional): The ID of the message to start the list from, - retrieving messages sent before this one. - after_message (str, optional): The ID of the message to start the list from, - retrieving messages sent after this one. - json_compatible (bool, optional): If True, returns messages as dictionaries. - If False, returns messages as Message - objects. Default is False. - + before_message (str, optional): The ID of the message to start the + list from, retrieving messages sent before this one. + after_message (str, optional): The ID of the message to start the + list from, retrieving messages sent after this one. Returns: - list[Union[Message, dict]]: A list of messages from the thread, either - as dictionaries or Message objects, - depending on the value of json_compatible. + list[Union[Message, dict]]: A list of messages from the thread """ if self.id is None: @@ -134,10 +126,7 @@ async def get_messages_async( limit=limit, order="desc", ) - - T = dict if json_compatible else Message - - return parse_as(list[T], reversed(response.model_dump()["data"])) + return response.data @expose_sync_method("delete") async def delete_async(self): diff --git a/src/marvin/types.py b/src/marvin/types.py index bb106d790..304a41b34 100644 --- a/src/marvin/types.py +++ b/src/marvin/types.py @@ -73,11 +73,11 @@ class ToolSet(MarvinType, Generic[T]): tool_choice: Optional[Union[Literal["auto"], dict[str, Any]]] = None -class RetrievalTool(FunctionTool[T]): +class RetrievalTool(Tool): type: Literal["retrieval"] = "retrieval" -class CodeInterpreterTool(FunctionTool[T]): +class CodeInterpreterTool(Tool): type: Literal["code_interpreter"] = "code_interpreter" From 415d34c8c960db25cc8dbe34173ff8d379dd3b51 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:00:42 -0400 Subject: [PATCH 3/4] Update assistants.py --- src/marvin/tools/assistants.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/marvin/tools/assistants.py b/src/marvin/tools/assistants.py index 721665190..abe32c3cc 100644 --- a/src/marvin/tools/assistants.py +++ b/src/marvin/tools/assistants.py @@ -1,11 +1,11 @@ from typing import Any, Union -from marvin.types import CodeInterpreterTool, FunctionTool, RetrievalTool +from marvin.types import CodeInterpreterTool, RetrievalTool, Tool Retrieval = RetrievalTool() CodeInterpreter = CodeInterpreterTool() -AssistantTool = Union[RetrievalTool, CodeInterpreterTool, FunctionTool] +AssistantTool = Union[RetrievalTool, CodeInterpreterTool, Tool] class CancelRun(Exception): From 3443dafc7a79e802b00183fcd17d97c1a868ac53 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:16:47 -0400 Subject: [PATCH 4/4] Fix order of casting --- src/marvin/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/marvin/types.py b/src/marvin/types.py index 304a41b34..557c3fa3d 100644 --- a/src/marvin/types.py +++ b/src/marvin/types.py @@ -69,7 +69,7 @@ class FunctionTool(Tool, Generic[T]): class ToolSet(MarvinType, Generic[T]): - tools: Optional[list[Union[Tool, FunctionTool[T]]]] = None + tools: Optional[list[Union[FunctionTool[T], Tool]]] = None tool_choice: Optional[Union[Literal["auto"], dict[str, Any]]] = None