Skip to content

Commit

Permalink
feat: add print for sqlite error (#2221)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Dec 11, 2024
1 parent 5ae6d69 commit 65702c8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 6 deletions.
75 changes: 70 additions & 5 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N
raise NotImplementedError


from contextlib import contextmanager

from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

Expand All @@ -166,6 +171,37 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N

config = LettaConfig.load()


def print_sqlite_schema_error():
"""Print a formatted error message for SQLite schema issues"""
console = Console()
error_text = Text()
error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red")
error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white")
error_text.append("https://docs.letta.com/server/docker", style="blue underline")
error_text.append(") or use Postgres by setting ", style="white")
error_text.append("LETTA_PG_URI", style="yellow")
error_text.append(".\n\n", style="white")
error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white")
error_text.append("rm ~/.letta/sqlite.db", style="yellow")
error_text.append(" or downgrade to your previous version of Letta.", style="white")

console.print(Panel(error_text, border_style="red"))


@contextmanager
def db_error_handler():
"""Context manager for handling database errors"""
try:
yield
except Exception as e:
# Handle other SQLAlchemy errors
print(e)
print_sqlite_schema_error()
# raise ValueError(f"SQLite DB error: {str(e)}")
exit(1)


if settings.letta_pg_uri_no_default:
config.recall_storage_type = "postgres"
config.recall_storage_uri = settings.letta_pg_uri_no_default
Expand All @@ -178,6 +214,30 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N
# TODO: don't rely on config storage
engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))

# Store the original connect method
original_connect = engine.connect

def wrapped_connect(*args, **kwargs):
with db_error_handler():
# Get the connection
connection = original_connect(*args, **kwargs)

# Store the original execution method
original_execute = connection.execute

# Wrap the execute method of the connection
def wrapped_execute(*args, **kwargs):
with db_error_handler():
return original_execute(*args, **kwargs)

# Replace the connection's execute method
connection.execute = wrapped_execute

return connection

# Replace the engine's connect method
engine.connect = wrapped_connect

Base.metadata.create_all(bind=engine)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Expand Down Expand Up @@ -379,7 +439,9 @@ def initialize_agent(self, agent_id, interface: Union[AgentInterface, None] = No
if agent_state.agent_type == AgentType.memgpt_agent:
agent = Agent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence)
elif agent_state.agent_type == AgentType.offline_memory_agent:
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence)
agent = OfflineMemoryAgent(
agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence
)
else:
assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents"
agent = O1Agent(agent_state=agent_state, interface=interface, user=actor)
Expand Down Expand Up @@ -500,8 +562,8 @@ def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStati
letta_agent.attach_source(
user=self.user_manager.get_user_by_id(user_id=user_id),
source_id=data_source,
source_manager=letta_agent.source_manager,
ms=self.ms
source_manager=letta_agent.source_manager,
ms=self.ms,
)

elif command.lower() == "dump" or command.lower().startswith("dump "):
Expand Down Expand Up @@ -1267,7 +1329,10 @@ def get_agent_archival_cursor(

# iterate over records
records = letta_agent.passage_manager.list_passages(
actor=self.default_user, agent_id=agent_id, cursor=cursor, limit=limit,
actor=self.default_user,
agent_id=agent_id,
cursor=cursor,
limit=limit,
)
return records

Expand Down Expand Up @@ -1914,7 +1979,7 @@ def run_tool_from_source(
date=get_utc_time(),
status="error",
function_return=error_msg,
stdout=[''],
stdout=[""],
stderr=[traceback.format_exc()],
)

Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ pathvalidate = "^3.2.1"
langchain-community = {version = "^0.3.7", optional = true}
langchain = {version = "^0.3.7", optional = true}
sentry-sdk = {extras = ["fastapi"], version = "2.19.1"}
rich = "^13.9.4"
brotli = "^1.1.0"
grpcio = "^1.68.1"
grpcio-tools = "^1.68.1"
Expand Down

0 comments on commit 65702c8

Please sign in to comment.