refactor agent create function #2054

Closed
152 changes: 53 additions & 99 deletions letta/llm_api/llm_api_tools.py
@@ -101,63 +101,55 @@ def wrapper(*args, **kwargs):

@retry_with_exponential_backoff
def create(
# agent_state: AgentState,
llm_config: LLMConfig,
messages: List[Message],
user_id: Optional[str] = None, # option UUID to associate request with
user_id: Optional[str] = None,
functions: Optional[list] = None,
functions_python: Optional[dict] = None,
function_call: str = "auto",
# hint
first_message: bool = False,
# use tool naming?
# if false, will use deprecated 'functions' style
use_tool_naming: bool = True,
# streaming?
stream: bool = False,
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
max_tokens: Optional[int] = None,
model_settings: Optional[dict] = None, # TODO: eventually pass from server
model_settings: Optional[dict] = None,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
"""Return response to chat completion with backoff."""
from letta.utils import printd

if not model_settings:
from letta.settings import model_settings

model_settings = model_settings

model_settings = model_settings
printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}")

if function_call and not functions:
printd("unsetting function_call because functions is None")
printd("Unsetting function_call because functions is None")
function_call = None

# openai
if llm_config.model_endpoint_type == "openai":
def handle_openai():
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise ValueError(f"OpenAI key is missing from letta config file")

data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
stream_interface, AgentRefreshStreamingInterface
), type(stream_interface)
raise ValueError("OpenAI key is missing from letta config file")

data = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens
)
data.stream = stream

if stream:
assert isinstance(stream_interface, (AgentChunkStreamingInterface, AgentRefreshStreamingInterface))
response = openai_chat_completions_process_stream(
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
url=llm_config.model_endpoint,
api_key=model_settings.openai_api_key,
chat_completion_request=data,
stream_interface=stream_interface,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
else:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
response = openai_chat_completions_request(
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
url=llm_config.model_endpoint,
api_key=model_settings.openai_api_key,
chat_completion_request=data,
)
@@ -170,22 +162,13 @@ def create(

return response

# azure
elif llm_config.model_endpoint_type == "azure":
def handle_azure():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
raise NotImplementedError("Streaming not yet implemented for Azure")

if model_settings.azure_api_key is None:
raise ValueError(f"Azure API key is missing. Did you set AZURE_API_KEY in your env?")
if not all([model_settings.azure_api_key, model_settings.azure_base_url, model_settings.azure_api_version]):
raise ValueError("Azure API key, base URL, or version is missing. Check your environment variables.")

if model_settings.azure_base_url is None:
raise ValueError(f"Azure base url is missing. Did you set AZURE_BASE_URL in your env?")

if model_settings.azure_api_version is None:
raise ValueError(f"Azure API version is missing. Did you set AZURE_API_VERSION in your env?")

# Set the llm config model_endpoint from model_settings
# For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
llm_config.model_endpoint = model_settings.azure_base_url
chat_completion_request = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens
@@ -203,34 +186,28 @@ def create(

return response

elif llm_config.model_endpoint_type == "google_ai":
def handle_google_ai():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
raise NotImplementedError("Streaming not yet implemented for Google AI")
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Google AI API requests")

if functions is not None:
tools = [{"type": "function", "function": f} for f in functions]
tools = [Tool(**t) for t in tools]
tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs)
else:
tools = None
tools = convert_tools_to_google_ai_format(
[{"type": "function", "function": f} for f in functions] if functions else None,
inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs
)

return google_ai_chat_completions_request(
base_url=llm_config.model_endpoint,
model=llm_config.model,
api_key=model_settings.gemini_api_key,
# see structure of payload here: https://ai.google.dev/docs/function_calling
data=dict(
contents=[m.to_google_ai_dict() for m in messages],
tools=tools,
),
data=dict(contents=[m.to_google_ai_dict() for m in messages], tools=tools),
inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
)

elif llm_config.model_endpoint_type == "anthropic":
def handle_anthropic():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
raise NotImplementedError("Streaming not yet implemented for Anthropic")
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Anthropic API requests")

@@ -241,56 +218,25 @@ def create(
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
# tool_choice=function_call,
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
max_tokens=1024, # TODO make dynamic
max_tokens=1024,
),
)

# elif llm_config.model_endpoint_type == "cohere":
# if stream:
# raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
# if not use_tool_naming:
# raise NotImplementedError("Only tool calling supported on Cohere API requests")
#
# if functions is not None:
# tools = [{"type": "function", "function": f} for f in functions]
# tools = [Tool(**t) for t in tools]
# else:
# tools = None
#
# return cohere_chat_completions_request(
# # url=llm_config.model_endpoint,
# url="https://api.cohere.ai/v1", # TODO
# api_key=os.getenv("COHERE_API_KEY"), # TODO remove
# chat_completion_request=ChatCompletionRequest(
# model="command-r-plus", # TODO
# messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
# tools=tools,
# tool_choice=function_call,
# # user=str(user_id),
# # NOTE: max_tokens is required for Anthropic API
# # max_tokens=1024, # TODO make dynamic
# ),
# )

elif llm_config.model_endpoint_type == "groq":
def handle_groq():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for Groq.")
raise NotImplementedError("Streaming not yet implemented for Groq")

if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions":
raise ValueError(f"Groq key is missing from letta config file")
raise ValueError("Groq key is missing from letta config file")

# force to true for groq, since they don't support 'content' is non-null
if llm_config.put_inner_thoughts_in_kwargs:
functions = add_inner_thoughts_to_functions(
functions=functions,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
)

tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None
tools = [{"type": "function", "function": f} for f in functions] if functions else None
data = ChatCompletionRequest(
model=llm_config.model,
messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) for m in messages],
@@ -299,19 +245,15 @@ def create(
user=str(user_id),
)

# https://console.groq.com/docs/openai
# "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:"
assert data.top_logprobs is None
assert data.logit_bias is None
assert data.logprobs == False
assert data.n == 1
# They mention that none of the messages can have names, but it seems to not error out (for now)

data.stream = False
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
# groq uses the openai chat completions API, so this component should be reusable
response = openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.groq_api_key,
Expand All @@ -326,10 +268,9 @@ def create(

return response

# local model
else:
def handle_local():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
raise NotImplementedError("Streaming not yet implemented for local models")
return get_chat_completion(
model=llm_config.model,
messages=messages,
@@ -341,9 +282,22 @@ def create(
endpoint_type=llm_config.model_endpoint_type,
wrapper=llm_config.model_wrapper,
user=str(user_id),
# hint
first_message=first_message,
# auth-related
auth_type=model_settings.openllm_auth_type,
auth_key=model_settings.openllm_api_key,
)

handlers = {
"openai": handle_openai,
"azure": handle_azure,
"google_ai": handle_google_ai,
"anthropic": handle_anthropic,
"groq": handle_groq,
"local": handle_local,
}

handler = handlers.get(llm_config.model_endpoint_type)
if handler:
return handler()
else:
raise NotImplementedError(f"Model endpoint type '{llm_config.model_endpoint_type}' is not supported.")