fix(health.md): add rerank model health check information (#7295)
* fix(health.md): add rerank model health check information

* build(model_prices_and_context_window.json): add gemini 2.0 for google ai studio - pricing + commercial rate limits

* build(model_prices_and_context_window.json): add gemini-2.0 supports audio output = true

* docs(team_model_add.md): clarify allowing teams to add models is an enterprise feature

* fix(o1_transformation.py): add support for 'n', 'response_format', and 'stop' params for o1 and 'stream_options' param for o1-mini

* build(model_prices_and_context_window.json): add 'supports_system_message' to supporting openai models

needed as o1-preview and o1-mini models don't support 'system message'

* fix(o1_transformation.py): translate system message based on whether the o1 model supports it (see the first sketch after this list)

* fix(o1_transformation.py): return 'stream' param support for o1-mini/o1-preview

o1 currently doesn't support streaming, but the other model versions do

Fixes #7292

* fix(o1_transformation.py): return tool calling/response_format in supported params if model map says so

Fixes #7292

* fix: fix linting errors

* fix: update '_transform_messages'

* fix(o1_transformation.py): fix provider passed for supported param checks

* test(base_llm_unit_tests.py): skip test if api takes >5s to respond

* fix(utils.py): return False in 'supports_factory' if the value can't be found

* fix(o1_transformation.py): always return stream + stream_options as supported params + handle stream options being passed in for azure o1

* feat(openai.py): support stream faking natively in openai handler

Allows streaming to be faked for just the base "o1" model, while allowing native streaming for o1-mini and o1-preview (see the second sketch after this list)

Fixes #7292

* fix(openai.py): use inference param instead of original optional param
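
For the system-message change above, a minimal sketch of the translation, assuming a boolean model flag like `supports_system_message` — illustrative names, not litellm's exact internals:

```python
from typing import Dict, List


def translate_system_messages(
    messages: List[Dict[str, str]], supports_system_message: bool
) -> List[Dict[str, str]]:
    """Rewrite 'system' messages as 'user' messages for models that reject them."""
    if supports_system_message:
        return messages
    return [
        {**m, "role": "user"} if m.get("role") == "system" else m
        for m in messages
    ]
```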
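And a hedged sketch of the stream-faking idea: run a regular (non-streaming) completion, then replay it as a one-chunk stream. Again illustrative, not the handler's actual code:

```python
from typing import Any, Dict, Iterator


def fake_stream(full_response: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
    """Replay a complete response as a single delta chunk plus a stop chunk."""
    content = full_response["choices"][0]["message"]["content"]
    yield {"choices": [{"delta": {"content": content}, "finish_reason": None}]}
    # Terminal chunk marking the end of the (fake) stream.
    yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}
```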
krrishdholakia authored Dec 19, 2024
1 parent 6a45ee1 commit 5253f63
Showing 34 changed files with 665 additions and 380 deletions.
14 changes: 14 additions & 0 deletions docs/my-website/docs/proxy/health.md
@@ -121,6 +121,20 @@ model_list:
mode: audio_speech
```

### Rerank Models

To run rerank health checks, set `mode: rerank` for the relevant model in your config.

```yaml
model_list:
- model_name: rerank-english-v3.0
litellm_params:
model: cohere/rerank-english-v3.0
api_key: os.environ/COHERE_API_KEY
model_info:
mode: rerank
```
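
A rerank health check exercises the same path as a small rerank request. For reference, a minimal sketch of the equivalent SDK call — assuming `COHERE_API_KEY` is exported and using the public `litellm.rerank` API:

```python
import litellm

# Small rerank request against the model configured above.
response = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
)
print(response)
```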

### Batch Models (Azure Only)

For Azure models deployed as 'batch' models, set `mode: batch`.
11 changes: 10 additions & 1 deletion docs/my-website/docs/proxy/team_model_add.md
@@ -1,4 +1,13 @@
# Allow Teams to Add Models
# ✨ Allow Teams to Add Models

:::info

This is an Enterprise feature.
[Enterprise Pricing](https://www.litellm.ai/#pricing)

[Contact us here to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)

:::

Allow teams to add their own models/keys for a project - so any OpenAI call they make uses their own OpenAI key.

4 changes: 3 additions & 1 deletion litellm/litellm_core_utils/prompt_templates/factory.py
@@ -3144,7 +3144,9 @@ def prompt_factory(
else:
return gemini_text_image_pt(messages=messages)
elif custom_llm_provider == "mistral":
return litellm.MistralConfig()._transform_messages(messages=messages)
return litellm.MistralConfig()._transform_messages(
messages=messages, model=model
)
elif custom_llm_provider == "bedrock":
if "amazon.titan-text" in model:
return amazon_titan_pt(messages=messages)
6 changes: 0 additions & 6 deletions litellm/llms/anthropic/completion/transformation.py
@@ -260,12 +260,6 @@ def _get_anthropic_text_prompt_from_messages(

return str(prompt)

def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"Not required"
raise NotImplementedError

def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
2 changes: 2 additions & 0 deletions litellm/llms/azure/chat/o1_handler.py
@@ -57,6 +57,7 @@ def completion(
client=None,
):
stream: Optional[bool] = optional_params.pop("stream", False)
stream_options: Optional[dict] = optional_params.pop("stream_options", None)
response = super().completion(
model,
messages,
@@ -90,6 +91,7 @@ def completion(
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=stream_options,
)

return streaming_response
3 changes: 2 additions & 1 deletion litellm/llms/azure_ai/chat/transformation.py
@@ -2,11 +2,11 @@

import litellm
from litellm._logging import verbose_logger
from litellm.llms.openai.openai import OpenAIConfig
from litellm.litellm_core_utils.prompt_templates.common_utils import (
_audio_or_image_in_message_content,
convert_content_list_to_str,
)
from litellm.llms.openai.openai import OpenAIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField
@@ -33,6 +33,7 @@ def get_required_params(self) -> List[ProviderField]:
def _transform_messages(
self,
messages: List[AllMessageValues],
model: str,
) -> List:
"""
- Azure AI Studio doesn't support content as a list. This handles:
8 changes: 8 additions & 0 deletions litellm/llms/base_llm/chat/transformation.py
@@ -82,6 +82,14 @@ def get_config(cls):
and v is not None
}

def should_fake_stream(
self, model: str, custom_llm_provider: Optional[str] = None
) -> bool:
"""
Returns True if the model/provider should fake stream
"""
return False

@abstractmethod
def get_supported_openai_params(self, model: str) -> list:
pass
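
A provider config can opt into stream faking by overriding this hook. A hypothetical override, reflecting the commit message's note that only the base "o1" model lacks native streaming:

```python
from typing import Optional


class ExampleO1Config(BaseConfig):  # BaseConfig as defined in the diff above
    def should_fake_stream(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> bool:
        # Only the base "o1" model needs a faked stream; o1-mini and
        # o1-preview stream natively per the commit message.
        return model == "o1"
```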
5 changes: 0 additions & 5 deletions litellm/llms/clarifai/chat/transformation.py
@@ -131,11 +131,6 @@ def validate_environment(
headers["Authorization"] = f"Bearer {api_key}"
return headers

def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
raise NotImplementedError

def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
5 changes: 0 additions & 5 deletions litellm/llms/cloudflare/chat/transformation.py
@@ -158,11 +158,6 @@ def get_error_class(
message=error_message,
)

def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
raise NotImplementedError

def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
5 changes: 0 additions & 5 deletions litellm/llms/cohere/chat/transformation.py
@@ -365,8 +365,3 @@ def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)

def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
raise NotImplementedError
6 changes: 0 additions & 6 deletions litellm/llms/cohere/completion/transformation.py
@@ -121,12 +121,6 @@ def validate_environment(
api_key=api_key,
)

def _transform_messages(
self,
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
raise NotImplementedError

def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
7 changes: 5 additions & 2 deletions litellm/llms/databricks/chat/handler.py
@@ -2,11 +2,12 @@
Handles the chat completion request for Databricks
"""

from typing import Any, Callable, Literal, Optional, Tuple, Union
from typing import Any, Callable, List, Literal, Optional, Tuple, Union, cast

from httpx._config import Timeout

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CustomStreamingDecoder
from litellm.utils import ModelResponse

@@ -44,7 +45,9 @@ def completion(
streaming_decoder: Optional[CustomStreamingDecoder] = None,
fake_stream: bool = False,
):
messages = DatabricksConfig()._transform_messages(messages) # type: ignore
messages = DatabricksConfig()._transform_messages(
messages=cast(List[AllMessageValues], messages), model=model
)
api_base, headers = self.databricks_validate_environment(
api_base=api_base,
api_key=api_key,
12 changes: 6 additions & 6 deletions litellm/llms/databricks/chat/transformation.py
@@ -7,14 +7,14 @@

from pydantic import BaseModel

from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField

from ...openai_like.chat.transformation import OpenAILikeChatConfig
from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
strip_name_from_messages,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField

from ...openai_like.chat.transformation import OpenAILikeChatConfig


class DatabricksConfig(OpenAILikeChatConfig):
@@ -86,7 +86,7 @@ def _should_fake_stream(self, optional_params: dict) -> bool:
return False

def _transform_messages(
self, messages: List[AllMessageValues]
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:
"""
Databricks does not support:
@@ -102,4 +102,4 @@ def _transform_messages(
new_messages.append(_message)
new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
new_messages = strip_name_from_messages(new_messages)
return super()._transform_messages(new_messages)
return super()._transform_messages(messages=new_messages, model=model)
10 changes: 5 additions & 5 deletions litellm/llms/deepseek/chat/transformation.py
@@ -8,26 +8,26 @@
from pydantic import BaseModel

import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage

from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
)


class DeepSeekChatConfig(OpenAIGPTConfig):

def _transform_messages(
self, messages: List[AllMessageValues]
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:
"""
DeepSeek does not support content in list format.
"""
messages = handle_messages_with_content_list_to_str_conversion(messages)
return super()._transform_messages(messages)
return super()._transform_messages(messages=messages, model=model)

def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
7 changes: 5 additions & 2 deletions litellm/llms/groq/chat/handler.py
@@ -2,11 +2,12 @@
Handles the chat completion request for groq
"""

from typing import Any, Callable, Optional, Union
from typing import Any, Callable, List, Optional, Union, cast

from httpx._config import Timeout

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CustomStreamingDecoder
from litellm.utils import ModelResponse

@@ -42,7 +43,9 @@ def completion(
streaming_decoder: Optional[CustomStreamingDecoder] = None,
fake_stream: bool = False,
):
messages = GroqChatConfig()._transform_messages(messages) # type: ignore
messages = GroqChatConfig()._transform_messages(
messages=cast(List[AllMessageValues], messages), model=model
)

if optional_params.get("stream") is True:
fake_stream = GroqChatConfig()._should_fake_stream(optional_params)
2 changes: 1 addition & 1 deletion litellm/llms/groq/chat/transformation.py
@@ -61,7 +61,7 @@ def __init__(
def get_config(cls):
return super().get_config()

def _transform_messages(self, messages: List[AllMessageValues]) -> List:
def _transform_messages(self, messages: List[AllMessageValues], model: str) -> List:
for idx, message in enumerate(messages):
"""
1. Don't pass 'null' function_call assistant message to groq - https://github.com/BerriAI/litellm/issues/5839
6 changes: 0 additions & 6 deletions litellm/llms/huggingface/chat/transformation.py
@@ -369,12 +369,6 @@ def validate_environment(
headers = {**headers, **default_headers}
return headers

def _transform_messages(
self,
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
return messages

def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
4 changes: 2 additions & 2 deletions litellm/llms/mistral/mistral_chat_transformation.py
@@ -9,11 +9,11 @@
import types
from typing import List, Literal, Optional, Tuple, Union

from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
strip_none_values_from_message,
)
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues

@@ -148,7 +148,7 @@ def _get_openai_compatible_provider_info(
return api_base, dynamic_api_key

def _transform_messages(
self, messages: List[AllMessageValues]
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:
"""
- handles scenario where content is list and not string
11 changes: 3 additions & 8 deletions litellm/llms/ollama/completion/transformation.py
@@ -23,6 +23,7 @@
from litellm.types.utils import (
GenericStreamingChunk,
ModelInfo,
ModelInfoBase,
ModelResponse,
ProviderField,
StreamingChoices,
@@ -198,7 +199,7 @@ def _get_max_tokens(self, ollama_model_info: dict) -> Optional[int]:
return v
return None

def get_model_info(self, model: str) -> ModelInfo:
def get_model_info(self, model: str) -> ModelInfoBase:
"""
curl http://localhost:11434/api/show -d '{
"name": "mistral"
@@ -222,11 +223,10 @@ def get_model_info(self, model: str) -> ModelInfo:

_max_tokens: Optional[int] = self._get_max_tokens(model_info)

return ModelInfo(
return ModelInfoBase(
key=model,
litellm_provider="ollama",
mode="chat",
supported_openai_params=self.get_supported_openai_params(model=model),
supports_function_calling=self._supports_function_calling(model_info),
input_cost_per_token=0.0,
output_cost_per_token=0.0,
@@ -235,11 +235,6 @@ def get_model_info(self, model: str) -> ModelInfo:
max_output_tokens=_max_tokens,
)

def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
return messages

def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
5 changes: 0 additions & 5 deletions litellm/llms/oobabooga/chat/transformation.py
@@ -23,11 +23,6 @@


class OobaboogaConfig(OpenAIGPTConfig):
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
return messages

def get_error_class(
self,
error_message: str,
2 changes: 1 addition & 1 deletion litellm/llms/openai/chat/gpt_transformation.py
@@ -164,7 +164,7 @@ def map_openai_params(
)

def _transform_messages(
self, messages: List[AllMessageValues]
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:
return messages

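
The change recurring across the provider configs in this commit is threading `model` through `_transform_messages`, so subclasses can branch on the concrete model name. A hypothetical subclass under the new signature — illustrative only, not part of the commit:

```python
from typing import List

from litellm.types.llms.openai import AllMessageValues


class ExampleProviderConfig(OpenAIGPTConfig):  # OpenAIGPTConfig from the diff above
    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        # Example: drop a provider-unsupported field only for certain models.
        if model.startswith("o1"):
            messages = [
                {k: v for k, v in m.items() if k != "name"}  # type: ignore
                for m in messages
            ]
        return super()._transform_messages(messages=messages, model=model)
```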
