diff --git a/litellm/llms/ollama/common_utils.py b/litellm/llms/ollama/common_utils.py
index 5cf213950c16..2450eeccf946 100644
--- a/litellm/llms/ollama/common_utils.py
+++ b/litellm/llms/ollama/common_utils.py
@@ -1,7 +1,8 @@
-from typing import Union
+from typing import Optional, Union
 
 import httpx
 
+from litellm._logging import verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
@@ -43,3 +44,26 @@ def _convert_image(image):
     image_data.convert("RGB").save(jpeg_image, "JPEG")
     jpeg_image.seek(0)
     return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
+
+
+def process_response_format(response_format: dict) -> Optional[Union[dict, str]]:
+    """
+    Process an OpenAI-style response_format specification into the format value
+    expected by the Ollama API.
+
+    Args:
+        response_format (dict): OpenAI-style response format specification
+
+    Returns:
+        str or dict: "json" for json_object, or the JSON schema dict for json_schema
+    """
+    format_type = response_format.get("type")
+    if format_type == "json_object":
+        return "json"
+    elif format_type == "json_schema":
+        schema = response_format.get("json_schema", {}).get("schema")
+        if not schema:
+            raise ValueError("Invalid JSON schema format")
+        return schema
+    else:
+        verbose_logger.warning(f"Unsupported response format type: {format_type}")
diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py
index 52198893219f..33b009bf30d1 100644
--- a/litellm/llms/ollama/completion/transformation.py
+++ b/litellm/llms/ollama/completion/transformation.py
@@ -22,7 +22,7 @@
     ProviderField,
 )
 
-from ..common_utils import OllamaError, _convert_image
+from ..common_utils import OllamaError, _convert_image, process_response_format
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -172,8 +172,9 @@ def map_openai_params(
             if param == "stop":
                 optional_params["stop"] = value
             if param == "response_format" and isinstance(value, dict):
-                if value["type"] == "json_object":
-                    optional_params["format"] = "json"
+                format = process_response_format(value)
+                if format is not None:
+                    optional_params["format"] = format
         return optional_params
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index 5aa26ced46dd..3ce69aa805b6 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -10,6 +10,7 @@
 import litellm
 from litellm import verbose_logger
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.llms.ollama.common_utils import process_response_format
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
@@ -152,8 +153,10 @@ def map_openai_params(
                 optional_params["repeat_penalty"] = value
             if param == "stop":
                 optional_params["stop"] = value
-            if param == "response_format" and value["type"] == "json_object":
-                optional_params["format"] = "json"
+            if param == "response_format" and isinstance(value, dict):
+                format = process_response_format(value)
+                if format is not None:
+                    optional_params["format"] = format
             ### FUNCTION CALLING LOGIC ###
             if param == "tools":
                 # ollama actually supports json output
diff --git a/tests/local_testing/test_ollama.py b/tests/local_testing/test_ollama.py
index 2066859091e8..06124061ddbf 100644
--- a/tests/local_testing/test_ollama.py
+++ b/tests/local_testing/test_ollama.py
@@ -1,13 +1,12 @@
 import asyncio
+import json
 import os
 import sys
-import traceback
 
 from dotenv import load_dotenv
+from pydantic import BaseModel
 
 load_dotenv()
-import io
-import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -19,7 +18,12 @@
 import litellm
 
 ## for ollama we can't test making the completion call
-from litellm.utils import EmbeddingResponse, get_llm_provider, get_optional_params
+from litellm.utils import (
+    EmbeddingResponse,
+    get_llm_provider,
+    get_optional_params,
+    type_to_response_format_param,
+)
 
 
 def test_get_ollama_params():
@@ -76,6 +80,31 @@ def test_ollama_json_mode():
 # test_ollama_json_mode()
 
 
+def test_ollama_structured_format():
+    # assert that "format" receives a structured JSON schema
+    class Country(BaseModel):
+        name: str
+        code: str
+        languages: list[str]
+
+    model_schema = type_to_response_format_param(Country) or {}
+    model_schema = model_schema.get("json_schema", {}).get("schema", {})
+    try:
+        converted_params = get_optional_params(
+            custom_llm_provider="ollama",
+            model="llama2",
+            temperature=0.5,
+            response_format=Country,
+        )
+        assert converted_params == {
+            "temperature": 0.5,
+            "format": model_schema,
+            "stream": False,
+        }, f"{converted_params} != {{'temperature': 0.5, 'format': {model_schema}, 'stream': False}}"
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
 
 
@@ -135,7 +164,6 @@ def test_ollama_aembeddings(mock_aembeddings):
 
 @pytest.mark.skip(reason="local only test")
 def test_ollama_chat_function_calling():
-    import json
 
     tools = [
         {
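
For context, a minimal end-to-end sketch of what this change enables (not part of the diff): it assumes a locally running Ollama server, and the model name "llama3.1" and the prompt are illustrative rather than taken from the PR.

    import litellm
    from pydantic import BaseModel


    class Country(BaseModel):
        name: str
        code: str
        languages: list[str]


    # Passing a Pydantic model as response_format makes litellm build a
    # json_schema response_format; with this change, process_response_format()
    # forwards the JSON schema to Ollama's "format" parameter instead of
    # falling back to plain "json" mode.
    response = litellm.completion(
        model="ollama_chat/llama3.1",  # assumed model; requires `ollama pull llama3.1`
        messages=[{"role": "user", "content": "Tell me about Canada."}],
        response_format=Country,
    )
    print(response.choices[0].message.content)  # JSON constrained to the Country schema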