feat(openai.py): support stream faking natively in openai handler
Allows streaming to be faked for just the "o1" model, while keeping native streaming for o1-mini and o1-preview.

 Fixes #7292
krrishdholakia committed Dec 19, 2024
1 parent a082a72 commit a702326
Showing 5 changed files with 95 additions and 37 deletions.
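
For context, a minimal caller-side sketch of what this change enables (hypothetical usage, not part of the commit; assumes an OpenAI API key is configured and the account has access to these models): with stream=True, "o1" is served by faking the stream from a single non-streaming response, while o1-mini and o1-preview keep native streaming.

import litellm

messages = [{"role": "user", "content": "Say hello"}]

# "o1" cannot stream natively, so the handler sends a regular request and
# replays the complete response through CustomStreamWrapper (fake stream).
for chunk in litellm.completion(model="o1", messages=messages, stream=True):
    print(chunk.choices[0].delta.content or "", end="")

# "o1-mini" streams natively, so chunks arrive from OpenAI as usual.
for chunk in litellm.completion(model="o1-mini", messages=messages, stream=True):
    print(chunk.choices[0].delta.content or "", end="")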
8 changes: 8 additions & 0 deletions litellm/llms/base_llm/chat/transformation.py
@@ -82,6 +82,14 @@ def get_config(cls):
and v is not None
}

def should_fake_stream(
self, model: str, custom_llm_provider: Optional[str] = None
) -> bool:
"""
Returns True if the model/provider should fake stream
"""
return False

@abstractmethod
def get_supported_openai_params(self, model: str) -> list:
pass
9 changes: 9 additions & 0 deletions litellm/llms/openai/chat/o1_transformation.py
@@ -36,6 +36,15 @@ class OpenAIO1Config(OpenAIGPTConfig):
def get_config(cls):
return super().get_config()

def should_fake_stream(
self, model: str, custom_llm_provider: Optional[str] = None
) -> bool:
"""
Fake streaming for the base "o1" model; o1-mini and o1-preview support native streaming.
"""
supported_stream_models = ["o1-mini", "o1-preview"]
for supported_model in supported_stream_models:
if supported_model in model:
return False
return True

def get_supported_openai_params(self, model: str) -> list:
"""
Get the supported OpenAI params for the given model
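
A quick illustration of the override above (a sketch, not part of the commit; the import path is taken from the file header and the config class is assumed to be instantiable with no arguments):

from litellm.llms.openai.chat.o1_transformation import OpenAIO1Config

config = OpenAIO1Config()

assert config.should_fake_stream(model="o1") is True            # no native streaming -> fake it
assert config.should_fake_stream(model="o1-mini") is False      # streams natively
assert config.should_fake_stream(model="o1-preview") is False   # streams natively
assert config.should_fake_stream(model="openai/o1-mini") is False  # substring match covers prefixed names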
79 changes: 73 additions & 6 deletions litellm/llms/openai/openai.py
@@ -33,6 +33,7 @@
prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import (
@@ -410,6 +411,24 @@ def make_sync_openai_chat_completion_request(
else:
raise e

def mock_streaming(
self,
response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
model: str,
stream_options: Optional[dict] = None,
) -> CustomStreamWrapper:
"""
Wrap a complete (non-streaming) response so it can be returned to the caller as a stream.
"""
completion_stream = MockResponseIterator(model_response=response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=stream_options,
)

return streaming_response

def completion( # type: ignore # noqa: PLR0915
self,
model_response: ModelResponse,
@@ -433,8 +452,21 @@ def completion( # type: ignore # noqa: PLR0915
):
super().completion()
try:
fake_stream: bool = False
if custom_llm_provider is not None and model is not None:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
fake_stream = provider_config.should_fake_stream(
model=model, custom_llm_provider=custom_llm_provider
)
inference_params = optional_params.copy()
stream_options: Optional[dict] = inference_params.pop(
"stream_options", None
)
stream: Optional[bool] = inference_params.pop("stream", False)
if headers:
optional_params["extra_headers"] = headers
inference_params["extra_headers"] = headers
if model is None or messages is None:
raise OpenAIError(status_code=422, message="Missing model or messages")

@@ -466,15 +498,15 @@ def completion( # type: ignore # noqa: PLR0915
data = OpenAIConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
optional_params=inference_params,
litellm_params=litellm_params,
headers=headers or {},
)

try:
max_retries = data.pop("max_retries", 2)
if acompletion is True:
if optional_params.get("stream", False):
if stream is True and fake_stream is False:
return self.async_streaming(
logging_obj=logging_obj,
headers=headers,
@@ -487,11 +519,13 @@ def completion( # type: ignore # noqa: PLR0915
max_retries=max_retries,
organization=organization,
drop_params=drop_params,
stream_options=stream_options,
)
else:
return self.acompletion(
data=data,
headers=headers,
model=model,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
@@ -501,8 +535,9 @@ def completion( # type: ignore # noqa: PLR0915
max_retries=max_retries,
organization=organization,
drop_params=drop_params,
fake_stream=fake_stream,
)
elif optional_params.get("stream", False):
elif stream is True and fake_stream is False:
return self.streaming(
logging_obj=logging_obj,
headers=headers,
@@ -514,6 +549,7 @@ def completion( # type: ignore # noqa: PLR0915
client=client,
max_retries=max_retries,
organization=organization,
stream_options=stream_options,
)
else:
if not isinstance(max_retries, int):
@@ -559,11 +595,21 @@ def completion( # type: ignore # noqa: PLR0915
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(
final_response_obj = convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
_response_headers=headers,
)

if fake_stream is True:
return self.mock_streaming(
response=cast(ModelResponse, final_response_obj),
logging_obj=logging_obj,
model=model,
stream_options=stream_options,
)

return final_response_obj
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
@@ -625,6 +671,7 @@ def completion( # type: ignore # noqa: PLR0915
async def acompletion(
self,
data: dict,
model: str,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
timeout: Union[float, httpx.Timeout],
@@ -635,6 +682,8 @@ async def acompletion(
max_retries=None,
headers=None,
drop_params: Optional[bool] = None,
stream_options: Optional[dict] = None,
fake_stream: bool = False,
):
response = None
for _ in range(
@@ -676,12 +725,22 @@ async def acompletion(
additional_args={"complete_input_dict": data},
)
logging_obj.model_call_details["response_headers"] = headers
return convert_to_model_response_object(
final_response_obj = convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
_response_headers=headers,
)

if fake_stream is True:
return self.mock_streaming(
response=cast(ModelResponse, final_response_obj),
logging_obj=logging_obj,
model=model,
stream_options=stream_options,
)

return final_response_obj
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
@@ -712,7 +771,11 @@ def streaming(
client=None,
max_retries=None,
headers=None,
stream_options: Optional[dict] = None,
):
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
@@ -763,8 +826,12 @@ async def async_streaming(
max_retries=None,
headers=None,
drop_params: Optional[bool] = None,
stream_options: Optional[dict] = None,
):
response = None
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
for _ in range(2):
try:
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
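
The fake-stream path above hinges on MockResponseIterator: a complete response is wrapped in an iterator that yields it as a single chunk, so callers can consume it like any other stream. A simplified, self-contained stand-in (illustrative only, not litellm's actual implementation):

from typing import Iterator


class SingleChunkStream:
    """Yield one complete response as if it were a streamed chunk (illustrative stand-in)."""

    def __init__(self, full_response: dict):
        self._response = full_response
        self._done = False

    def __iter__(self) -> Iterator[dict]:
        return self

    def __next__(self) -> dict:
        if self._done:
            raise StopIteration
        self._done = True
        return self._response  # the entire response arrives as the only "chunk"


for chunk in SingleChunkStream({"choices": [{"message": {"content": "hi"}}]}):
    print(chunk)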
27 changes: 0 additions & 27 deletions tests/llm_translation/test_openai_o1.py
@@ -65,33 +65,6 @@ async def test_o1_handle_system_role(model):
]


@pytest.mark.parametrize(
"model, expected_streaming_support",
[("o1-preview", True), ("o1-mini", True), ("o1", False)],
)
@pytest.mark.asyncio
async def test_o1_handle_streaming_optional_params(model, expected_streaming_support):
"""
Tests that 'stream' support is reported correctly for each o1 model
via get_supported_openai_params.
"""
from openai import AsyncOpenAI
from litellm.utils import ProviderConfigManager
from litellm.types.utils import LlmProviders

os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders.OPENAI
)

supported_params = config.get_supported_openai_params(model=model)

assert expected_streaming_support == ("stream" in supported_params)


@pytest.mark.parametrize(
"model, expected_tool_calling_support",
[("o1-preview", False), ("o1-mini", False), ("o1", True)],
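
The removed parametrized test above checked streaming support via get_supported_openai_params. A hypothetical equivalent against the new should_fake_stream hook could look like this (a sketch, not part of this commit):

import pytest

from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager


@pytest.mark.parametrize(
    "model, expected_fake_stream",
    [("o1-preview", False), ("o1-mini", False), ("o1", True)],
)
def test_o1_should_fake_stream(model, expected_fake_stream):
    # resolve the provider config for the model, then ask it whether to fake the stream
    config = ProviderConfigManager.get_provider_chat_config(
        model=model, provider=LlmProviders.OPENAI
    )
    assert config.should_fake_stream(model=model) == expected_fake_stream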
9 changes: 5 additions & 4 deletions tests/local_testing/test_streaming.py
@@ -2068,10 +2068,11 @@ def test_openai_chat_completion_complete_response_call():
@pytest.mark.parametrize(
"model",
[
# "gpt-3.5-turbo",
# "azure/chatgpt-v-2",
# "claude-3-haiku-20240307",
# "o1-preview",
"gpt-3.5-turbo",
"azure/chatgpt-v-2",
"claude-3-haiku-20240307",
"o1-preview",
"o1",
"azure/fake-o1-mini",
],
)
