Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update Tool schemas and improve provider integration #2349

Merged
merged 20 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
71031e9
feat: Extend tool runs to also take in environment variables (#554)
mattzh72 Jan 9, 2025
f29fb9d
feat: Add `return_char_limit` to `ToolUpdate` (#557)
mattzh72 Jan 9, 2025
cb86e88
fix: Write unit test for ensuring e2b composio version is valid (#560)
mattzh72 Jan 9, 2025
46e1fe6
chore: Merge OSS (#562)
mattzh72 Jan 9, 2025
23bc188
feat: Add sandbox_type filter when listing sandbox configurations (#567)
mattzh72 Jan 9, 2025
d008a04
chore: Deprecate O1 Agent (#573)
mattzh72 Jan 9, 2025
e85b5a4
feat: Add ToolType enum (#584)
mattzh72 Jan 10, 2025
2a92c57
feat: add lazy id generation for providers (#543)
carenthomas Jan 10, 2025
b80af68
feat: add updated_at timestamp to provider and bump on write (#574)
carenthomas Jan 10, 2025
111fcaa
chore: generate sdk + docs for provider api (#544)
carenthomas Jan 10, 2025
92ddb1d
fix: tag matching (#585)
mlong93 Jan 10, 2025
5861e8d
feat: v1 desktop ui (#586)
4shub Jan 10, 2025
696940e
feat: Add model integration testing (#587)
mattzh72 Jan 10, 2025
14a13b8
feat: Add `tool_type` column (#576)
mattzh72 Jan 10, 2025
629db6a
feat: Adjust tool execution sandbox to be under ~/.letta (#591)
mattzh72 Jan 10, 2025
199d09e
fix: Remove unique name restriction on agents (#592)
mattzh72 Jan 10, 2025
7380beb
fix: Adjust type of `args` to any for `ToolRunFromSource` (#593)
mattzh72 Jan 11, 2025
e1d0089
fix: Fix offline memory test (#597)
mattzh72 Jan 11, 2025
f1ed6cf
chore: Remove `deprecated_tool` mentions everywhere (#598)
mattzh72 Jan 11, 2025
2eb0735
Merge branch 'main' into update-schemas
sarahwooders Jan 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Remove unique name restriction on agents

Revision ID: cdb3db091113
Revises: e20573fe9b86
Create Date: 2025-01-10 15:36:08.728539

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "cdb3db091113"
down_revision: Union[str, None] = "e20573fe9b86"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_org_agent_name", "agents", type_="unique")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_unique_constraint("unique_org_agent_name", "agents", ["organization_id", "name"])
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@
depends_on: Union[str, Sequence[str], None] = None


def deprecated_tool():
return "this is a deprecated tool, please remove it from your tools list"


def upgrade() -> None:
# Delete all tools
op.execute("DELETE FROM tools")
Expand Down
70 changes: 70 additions & 0 deletions alembic/versions/e20573fe9b86_add_tool_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Add tool types

Revision ID: e20573fe9b86
Revises: 915b68780108
Create Date: 2025-01-09 15:11:47.779646

"""

from typing import Sequence, Union

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.orm.enums import ToolType

# revision identifiers, used by Alembic.
revision: str = "e20573fe9b86"
down_revision: Union[str, None] = "915b68780108"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Step 1: Add the column as nullable with no default
op.add_column("tools", sa.Column("tool_type", sa.String(), nullable=True))

# Step 2: Backpopulate the tool_type column based on tool name
# Define the list of Letta core tools
letta_core_value = ToolType.LETTA_CORE.value
letta_memory_core_value = ToolType.LETTA_MEMORY_CORE.value
custom_value = ToolType.CUSTOM.value

# Update tool_type for Letta core tools
op.execute(
f"""
UPDATE tools
SET tool_type = '{letta_core_value}'
WHERE name IN ({','.join(f"'{name}'" for name in BASE_TOOLS)});
"""
)

op.execute(
f"""
UPDATE tools
SET tool_type = '{letta_memory_core_value}'
WHERE name IN ({','.join(f"'{name}'" for name in BASE_MEMORY_TOOLS)});
"""
)

# Update tool_type for all other tools
op.execute(
f"""
UPDATE tools
SET tool_type = '{custom_value}'
WHERE tool_type IS NULL;
"""
)

# Step 3: Alter the column to be non-nullable
op.alter_column("tools", "tool_type", nullable=False)
op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)


def downgrade() -> None:
# Revert the changes made during the upgrade
op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
op.drop_column("tools", "tool_type")
# ### end Alembic commands ###
50 changes: 29 additions & 21 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,27 @@
from typing import List, Optional, Tuple, Union

from letta.constants import (
BASE_TOOLS,
CLI_WARNING_PREFIX,
ERROR_MESSAGE_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
LETTA_CORE_TOOL_MODULE_NAME,
LLM_MAX_TOKENS,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC,
REQ_HEARTBEAT_MESSAGE,
)
from letta.errors import ContextWindowExceededError
from letta.functions.functions import get_function_from_module
from letta.helpers import ToolRulesSolver
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.memory import summarize_messages
from letta.orm import User
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
Expand Down Expand Up @@ -153,7 +155,7 @@ def load_last_function_response(self):
raise ValueError(f"Invalid JSON format in message: {msg.text}")
return None

def update_memory_if_change(self, new_memory: Memory) -> bool:
def update_memory_if_changed(self, new_memory: Memory) -> bool:
"""
Update internal memory object and system prompt if there have been modifications.

Expand Down Expand Up @@ -192,39 +194,45 @@ def execute_tool_and_persist_state(self, function_name: str, function_args: dict
Execute tool modifications and persist the state of the agent.
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
"""
# TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args.
env = {}
env.update(globals())
exec(target_letta_tool.source_code, env)
callable_func = env[target_letta_tool.json_schema["name"]]
spec = inspect.getfullargspec(callable_func).annotations
for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])

# TODO: add agent manager here
orig_memory_str = self.agent_state.memory.compile()

# TODO: need to have an AgentState object that actually has full access to the block data
# this is because the sandbox tools need to be able to access block.value to edit this data
try:
# TODO: This is NO BUENO
# TODO: Matching purely by names is extremely problematic, users can create tools with these names and run them in the agent loop
# TODO: We will have probably have to match the function strings exactly for safety
if function_name in BASE_TOOLS:
if target_letta_tool.tool_type == ToolType.LETTA_CORE:
# base tools are allowed to access the `Agent` object and run on the database
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE:
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
agent_state_copy = self.agent_state.__deepcopy__()
function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
self.update_memory_if_changed(agent_state_copy.memory)
else:
# TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args.
env = {}
env.update(globals())
exec(target_letta_tool.source_code, env)
callable_func = env[target_letta_tool.json_schema["name"]]
spec = inspect.getfullargspec(callable_func).annotations
for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])

# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(
agent_state=self.agent_state.__deepcopy__()
)
# TODO: This is only temporary, can remove after we publish a pip package with this object
agent_state_copy = self.agent_state.__deepcopy__()
agent_state_copy.tools = []

sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(agent_state=agent_state_copy)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
if updated_agent_state is not None:
self.update_memory_if_change(updated_agent_state.memory)
self.update_memory_if_changed(updated_agent_state.memory)
except Exception as e:
# Need to catch error here, or else trunction wont happen
# TODO: modify to function execution error
Expand Down Expand Up @@ -677,7 +685,7 @@ def inner_step(
current_persisted_memory = Memory(
blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()]
) # read blocks from DB
self.update_memory_if_change(current_persisted_memory)
self.update_memory_if_changed(current_persisted_memory)

# Step 1: add user message
if isinstance(messages, Message):
Expand Down
4 changes: 4 additions & 0 deletions letta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING

LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
LETTA_DIR_TOOL_SANDBOX = os.path.join(LETTA_DIR, "tool_sandbox_dir")

ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"

COMPOSIO_ENTITY_ENV_VAR_KEY = "COMPOSIO_ENTITY"
COMPOSIO_TOOL_TAG_NAME = "composio"

LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base"

# String in the error message for when the context window is too large
# Example full message:
Expand Down
65 changes: 65 additions & 0 deletions letta/functions/functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import inspect
from textwrap import dedent # remove indentation
from types import ModuleType
Expand Down Expand Up @@ -64,6 +65,70 @@ def parse_source_code(func) -> str:
return source_code


def get_function_from_module(module_name: str, function_name: str):
"""
Dynamically imports a function from a specified module.

Args:
module_name (str): The name of the module to import (e.g., 'base').
function_name (str): The name of the function to retrieve.

Returns:
Callable: The imported function.

Raises:
ModuleNotFoundError: If the specified module cannot be found.
AttributeError: If the function is not found in the module.
"""
try:
# Dynamically import the module
module = importlib.import_module(module_name)
# Retrieve the function
return getattr(module, function_name)
except ModuleNotFoundError:
raise ModuleNotFoundError(f"Module '{module_name}' not found.")
except AttributeError:
raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.")


def get_json_schema_from_module(module_name: str, function_name: str) -> dict:
"""
Dynamically loads a specific function from a module and generates its JSON schema.

Args:
module_name (str): The name of the module to import (e.g., 'base').
function_name (str): The name of the function to retrieve.

Returns:
dict: The JSON schema for the specified function.

Raises:
ModuleNotFoundError: If the specified module cannot be found.
AttributeError: If the function is not found in the module.
ValueError: If the attribute is not a user-defined function.
"""
try:
# Dynamically import the module
module = importlib.import_module(module_name)

# Retrieve the function
attr = getattr(module, function_name, None)

# Check if it's a user-defined function
if not (inspect.isfunction(attr) and attr.__module__ == module.__name__):
raise ValueError(f"'{function_name}' is not a user-defined function in module '{module_name}'")

# Generate schema (assuming a `generate_schema` function exists)
generated_schema = generate_schema(attr)

return generated_schema

except ModuleNotFoundError:
raise ModuleNotFoundError(f"Module '{module_name}' not found.")
except AttributeError:
raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.")


def load_function_set(module: ModuleType) -> dict:
"""Load the functions and generate schema for them, given a module object"""
function_dict = {}
Expand Down
Loading
Loading