diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 95f0e5ac63..118d10830e 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -101,63 +101,55 @@ def wrapper(*args, **kwargs): @retry_with_exponential_backoff def create( - # agent_state: AgentState, llm_config: LLMConfig, messages: List[Message], - user_id: Optional[str] = None, # option UUID to associate request with + user_id: Optional[str] = None, functions: Optional[list] = None, functions_python: Optional[dict] = None, function_call: str = "auto", - # hint first_message: bool = False, - # use tool naming? - # if false, will use deprecated 'functions' style use_tool_naming: bool = True, - # streaming? stream: bool = False, stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None, max_tokens: Optional[int] = None, - model_settings: Optional[dict] = None, # TODO: eventually pass from server + model_settings: Optional[dict] = None, ) -> ChatCompletionResponse: - """Return response to chat completion with backoff""" + """Return response to chat completion with backoff.""" from letta.utils import printd if not model_settings: from letta.settings import model_settings - model_settings = model_settings - + model_settings = model_settings printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}") if function_call and not functions: - printd("unsetting function_call because functions is None") + printd("Unsetting function_call because functions is None") function_call = None - # openai - if llm_config.model_endpoint_type == "openai": + def handle_openai(): if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1": - # only is a problem if we are *not* using an openai proxy - raise ValueError(f"OpenAI key is missing from letta config file") - - data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens) - if stream: # Client requested token streaming - data.stream = True - assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( - stream_interface, AgentRefreshStreamingInterface - ), type(stream_interface) + raise ValueError("OpenAI key is missing from letta config file") + + data = build_openai_chat_completions_request( + llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens + ) + data.stream = stream + + if stream: + assert isinstance(stream_interface, (AgentChunkStreamingInterface, AgentRefreshStreamingInterface)) response = openai_chat_completions_process_stream( - url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + url=llm_config.model_endpoint, api_key=model_settings.openai_api_key, chat_completion_request=data, stream_interface=stream_interface, ) - else: # Client did not request token streaming (expect a blocking backend response) - data.stream = False + else: if isinstance(stream_interface, AgentChunkStreamingInterface): stream_interface.stream_start() try: response = openai_chat_completions_request( - url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + url=llm_config.model_endpoint, api_key=model_settings.openai_api_key, chat_completion_request=data, ) @@ -170,22 +162,13 @@ def create( return response - # azure - elif llm_config.model_endpoint_type == "azure": + def handle_azure(): if stream: - raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + raise NotImplementedError("Streaming not yet implemented for Azure") - if model_settings.azure_api_key is None: - raise ValueError(f"Azure API key is missing. Did you set AZURE_API_KEY in your env?") + if not all([model_settings.azure_api_key, model_settings.azure_base_url, model_settings.azure_api_version]): + raise ValueError("Azure API key, base URL, or version is missing. Check your environment variables.") - if model_settings.azure_base_url is None: - raise ValueError(f"Azure base url is missing. Did you set AZURE_BASE_URL in your env?") - - if model_settings.azure_api_version is None: - raise ValueError(f"Azure API version is missing. Did you set AZURE_API_VERSION in your env?") - - # Set the llm config model_endpoint from model_settings - # For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config llm_config.model_endpoint = model_settings.azure_base_url chat_completion_request = build_openai_chat_completions_request( llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens @@ -203,34 +186,28 @@ def create( return response - elif llm_config.model_endpoint_type == "google_ai": + def handle_google_ai(): if stream: - raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + raise NotImplementedError("Streaming not yet implemented for Google AI") if not use_tool_naming: raise NotImplementedError("Only tool calling supported on Google AI API requests") - if functions is not None: - tools = [{"type": "function", "function": f} for f in functions] - tools = [Tool(**t) for t in tools] - tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) - else: - tools = None + tools = convert_tools_to_google_ai_format( + [{"type": "function", "function": f} for f in functions] if functions else None, + inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs + ) return google_ai_chat_completions_request( base_url=llm_config.model_endpoint, model=llm_config.model, api_key=model_settings.gemini_api_key, - # see structure of payload here: https://ai.google.dev/docs/function_calling - data=dict( - contents=[m.to_google_ai_dict() for m in messages], - tools=tools, - ), + data=dict(contents=[m.to_google_ai_dict() for m in messages], tools=tools), inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, ) - elif llm_config.model_endpoint_type == "anthropic": + def handle_anthropic(): if stream: - raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + raise NotImplementedError("Streaming not yet implemented for Anthropic") if not use_tool_naming: raise NotImplementedError("Only tool calling supported on Anthropic API requests") @@ -241,48 +218,17 @@ def create( model=llm_config.model, messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], tools=[{"type": "function", "function": f} for f in functions] if functions else None, - # tool_choice=function_call, - # user=str(user_id), - # NOTE: max_tokens is required for Anthropic API - max_tokens=1024, # TODO make dynamic + max_tokens=1024, ), ) - # elif llm_config.model_endpoint_type == "cohere": - # if stream: - # raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") - # if not use_tool_naming: - # raise NotImplementedError("Only tool calling supported on Cohere API requests") - # - # if functions is not None: - # tools = [{"type": "function", "function": f} for f in functions] - # tools = [Tool(**t) for t in tools] - # else: - # tools = None - # - # return cohere_chat_completions_request( - # # url=llm_config.model_endpoint, - # url="https://api.cohere.ai/v1", # TODO - # api_key=os.getenv("COHERE_API_KEY"), # TODO remove - # chat_completion_request=ChatCompletionRequest( - # model="command-r-plus", # TODO - # messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], - # tools=tools, - # tool_choice=function_call, - # # user=str(user_id), - # # NOTE: max_tokens is required for Anthropic API - # # max_tokens=1024, # TODO make dynamic - # ), - # ) - - elif llm_config.model_endpoint_type == "groq": + def handle_groq(): if stream: - raise NotImplementedError(f"Streaming not yet implemented for Groq.") + raise NotImplementedError("Streaming not yet implemented for Groq") if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions": - raise ValueError(f"Groq key is missing from letta config file") + raise ValueError("Groq key is missing from letta config file") - # force to true for groq, since they don't support 'content' is non-null if llm_config.put_inner_thoughts_in_kwargs: functions = add_inner_thoughts_to_functions( functions=functions, @@ -290,7 +236,7 @@ def create( inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION, ) - tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None + tools = [{"type": "function", "function": f} for f in functions] if functions else None data = ChatCompletionRequest( model=llm_config.model, messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) for m in messages], @@ -299,19 +245,15 @@ def create( user=str(user_id), ) - # https://console.groq.com/docs/openai - # "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:" assert data.top_logprobs is None assert data.logit_bias is None assert data.logprobs == False assert data.n == 1 - # They mention that none of the messages can have names, but it seems to not error out (for now) data.stream = False if isinstance(stream_interface, AgentChunkStreamingInterface): stream_interface.stream_start() try: - # groq uses the openai chat completions API, so this component should be reusable response = openai_chat_completions_request( url=llm_config.model_endpoint, api_key=model_settings.groq_api_key, @@ -326,10 +268,9 @@ def create( return response - # local model - else: + def handle_local(): if stream: - raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + raise NotImplementedError("Streaming not yet implemented for local models") return get_chat_completion( model=llm_config.model, messages=messages, @@ -341,9 +282,22 @@ def create( endpoint_type=llm_config.model_endpoint_type, wrapper=llm_config.model_wrapper, user=str(user_id), - # hint first_message=first_message, - # auth-related auth_type=model_settings.openllm_auth_type, auth_key=model_settings.openllm_api_key, ) + + handlers = { + "openai": handle_openai, + "azure": handle_azure, + "google_ai": handle_google_ai, + "anthropic": handle_anthropic, + "groq": handle_groq, + "local": handle_local, + } + + handler = handlers.get(llm_config.model_endpoint_type) + if handler: + return handler() + else: + raise NotImplementedError(f"Model endpoint type '{llm_config.model_endpoint_type}' is not supported.") \ No newline at end of file