(Bug fix) missing model_group field in logs for aspeech call types #7392

Merged (5 commits) on Dec 28, 2024
litellm/router.py (18 changes: 5 additions & 13 deletions)
@@ -801,9 +801,7 @@ async def acompletion(
         kwargs["stream"] = stream
         kwargs["original_function"] = self._acompletion
         self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
-
         request_priority = kwargs.get("priority") or self.default_priority
-
         start_time = time.time()
         if request_priority is not None and isinstance(request_priority, int):
             response = await self.schedule_acompletion(**kwargs)
@@ -1422,7 +1420,7 @@ async def aimage_generation(self, prompt: str, model: str, **kwargs):
         kwargs["prompt"] = prompt
         kwargs["original_function"] = self._aimage_generation
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.setdefault("metadata", {}).update({"model_group": model})
+        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         response = await self.async_function_with_fallbacks(**kwargs)

         return response
@@ -1660,13 +1658,7 @@ async def aspeech(self, model: str, input: str, voice: str, **kwargs):
             messages=[{"role": "user", "content": "prompt"}],
             specific_deployment=kwargs.pop("specific_deployment", None),
         )
-        kwargs.setdefault("metadata", {}).update(
-            {
-                "deployment": deployment["litellm_params"]["model"],
-                "model_info": deployment.get("model_info", {}),
-            }
-        )
-        kwargs["model_info"] = deployment.get("model_info", {})
+        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         data = deployment["litellm_params"].copy()
         data["model"]
         for k, v in self.default_litellm_params.items():
@@ -1777,7 +1769,7 @@ async def _arealtime(self, model: str, **kwargs):
         messages = [{"role": "user", "content": "dummy-text"}]
         try:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)

             # pick the one that is available (lowest TPM/RPM)
             deployment = await self.async_get_available_deployment(
@@ -2215,7 +2207,7 @@ async def acreate_file(
         kwargs["model"] = model
         kwargs["original_function"] = self._acreate_file
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.setdefault("metadata", {}).update({"model_group": model})
+        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         response = await self.async_function_with_fallbacks(**kwargs)

         return response
@@ -2320,7 +2312,7 @@ async def acreate_batch(
         kwargs["model"] = model
         kwargs["original_function"] = self._acreate_batch
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.setdefault("metadata", {}).update({"model_group": model})
+        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         response = await self.async_function_with_fallbacks(**kwargs)

         return response
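Note on the change: each of these endpoints previously did its own partial kwargs["metadata"] update, and aspeech never set model_group at all; routing everything through _update_kwargs_before_fallbacks gives every call type the same metadata. As a rough sketch of the minimum such a helper needs to do (the actual body lives elsewhere in litellm/router.py and may set more fields):

    # Illustrative sketch only, not the actual litellm source.
    def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
        metadata = kwargs.setdefault("metadata", {})
        metadata["model_group"] = model  # the field aspeech logs were missing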
tests/logging_callback_tests/test_datadog.py (34 changes: 34 additions & 0 deletions)
@@ -498,3 +498,37 @@ def test_datadog_static_methods():
     expected_custom_tags = "env:production,service:custom-service,version:1.0.0,HOSTNAME:test-host,POD_NAME:pod-123"
     print("DataDogLogger._get_datadog_tags()", DataDogLogger._get_datadog_tags())
     assert DataDogLogger._get_datadog_tags() == expected_custom_tags
+
+
+@pytest.mark.asyncio
+async def test_datadog_non_serializable_messages():
+    """Test logging events with non-JSON-serializable messages"""
+    dd_logger = DataDogLogger()
+
+    # Create payload with non-serializable content
+    standard_payload = create_standard_logging_payload()
+    non_serializable_obj = datetime.now()  # datetime objects aren't JSON serializable
+    standard_payload["messages"] = [{"role": "user", "content": non_serializable_obj}]
+    standard_payload["response"] = {
+        "choices": [{"message": {"content": non_serializable_obj}}]
+    }
+
+    kwargs = {"standard_logging_object": standard_payload}
+
+    # Test payload creation
+    dd_payload = dd_logger.create_datadog_logging_payload(
+        kwargs=kwargs,
+        response_obj=None,
+        start_time=datetime.now(),
+        end_time=datetime.now(),
+    )
+
+    # Verify payload can be serialized
+    assert dd_payload["status"] == DataDogStatus.INFO
+
+    # Verify the message can be parsed back to dict
+    dict_payload = json.loads(dd_payload["message"])
+
+    # Check that the non-serializable objects were converted to strings
+    assert isinstance(dict_payload["messages"][0]["content"], str)
+    assert isinstance(dict_payload["response"]["choices"][0]["message"]["content"], str)
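The test expects create_datadog_logging_payload to coerce values json cannot encode into strings. A minimal sketch of that coercion pattern, assumed rather than copied from DataDogLogger's internals:

    import json
    from datetime import datetime

    # json.dumps(default=str) falls back to str() for values the encoder
    # can't handle, so a datetime round-trips as a plain string.
    payload = {"messages": [{"role": "user", "content": datetime.now()}]}
    message = json.dumps(payload, default=str)
    assert isinstance(json.loads(message)["messages"][0]["content"], str)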
tests/router_unit_tests/test_router_endpoints.py (23 changes: 19 additions & 4 deletions)
@@ -1,6 +1,8 @@
 import sys
 import os
+import json
 import traceback
+from typing import Optional
 from dotenv import load_dotenv
 from fastapi import Request
 from datetime import datetime
@@ -9,6 +11,7 @@
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from litellm import Router, CustomLogger
+from litellm.types.utils import StandardLoggingPayload

 # Get the current directory of the file being run
 pwd = os.path.dirname(os.path.realpath(__file__))
@@ -76,19 +79,20 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
             print("logging a transcript kwargs: ", kwargs)
             print("openai client=", kwargs.get("client"))
             self.openai_client = kwargs.get("client")
+            self.standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object"
+            )

         except Exception:
             pass


-proxy_handler_instance = MyCustomHandler()
-
-
 # Set litellm.callbacks = [proxy_handler_instance] on the proxy
 # need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
 @pytest.mark.asyncio
 @pytest.mark.flaky(retries=6, delay=10)
 async def test_transcription_on_router():
+    proxy_handler_instance = MyCustomHandler()
     litellm.set_verbose = True
     litellm.callbacks = [proxy_handler_instance]
     print("\n Testing async transcription on router\n")
@@ -150,7 +154,9 @@ async def test_transcription_on_router():
 @pytest.mark.parametrize("mode", ["iterator"])  # "file",
 @pytest.mark.asyncio
 async def test_audio_speech_router(mode):

     litellm.set_verbose = True
+    test_logger = MyCustomHandler()
+    litellm.callbacks = [test_logger]
     from litellm import Router

     client = Router(
@@ -178,10 +184,19 @@ async def test_audio_speech_router(mode):
         optional_params={},
     )

+    await asyncio.sleep(3)

     from litellm.llms.openai.openai import HttpxBinaryResponseContent

     assert isinstance(response, HttpxBinaryResponseContent)

+    assert test_logger.standard_logging_object is not None
+    print(
+        "standard_logging_object=",
+        json.dumps(test_logger.standard_logging_object, indent=4),
+    )
+    assert test_logger.standard_logging_object["model_group"] == "tts"


 @pytest.mark.asyncio()
 async def test_rerank_endpoint(model_list):
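For proxy users, the visible effect of the fix is that success callbacks now receive model_group for aspeech calls. A hypothetical user-side callback (the class name and print are illustrative, not part of the PR) mirroring what test_audio_speech_router asserts:

    import litellm
    from litellm import CustomLogger

    class ModelGroupLogger(CustomLogger):
        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            # After this fix, standard_logging_object carries model_group
            # for aspeech as well as completion call types.
            slp = kwargs.get("standard_logging_object") or {}
            print("model_group:", slp.get("model_group"))

    litellm.callbacks = [ModelGroupLogger()]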