feat: Add support for ollama structured outputs #7344

Open · wants to merge 2 commits into base: main
26 changes: 25 additions & 1 deletion litellm/llms/ollama/common_utils.py
@@ -1,7 +1,8 @@
from typing import Union
from typing import Optional, Union

import httpx

from litellm._logging import verbose_logger
from litellm.llms.base_llm.chat.transformation import BaseLLMException


@@ -43,3 +44,26 @@ def _convert_image(image):
    image_data.convert("RGB").save(jpeg_image, "JPEG")
    jpeg_image.seek(0)
    return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")


def process_response_format(response_format: dict) -> Optional[Union[str, dict]]:
    """
    Convert an OpenAI-style response_format specification into the value
    expected by the Ollama API's `format` parameter.

    Args:
        response_format (dict): OpenAI-style response format specification

    Returns:
        "json" for json_object, the JSON schema dict for json_schema, else None
    """
    format_type = response_format.get("type")
    if format_type == "json_object":
        return "json"
    elif format_type == "json_schema":
        schema = response_format.get("json_schema", {}).get("schema")
        if not schema:
            raise ValueError("Invalid JSON schema format")
        return schema
    else:
        verbose_logger.warning(f"Unsupported response format type: {format_type}")
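For a sense of how the new helper behaves, here is a small sketch (the schema dict below is illustrative and not part of this diff):

```python
from litellm.llms.ollama.common_utils import process_response_format

# OpenAI-style "json_object" maps to Ollama's plain JSON mode.
assert process_response_format({"type": "json_object"}) == "json"

# "json_schema" passes the raw schema dict through so Ollama can constrain output to it.
country_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "Country",  # hypothetical schema, for illustration only
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}, "code": {"type": "string"}},
            "required": ["name", "code"],
        },
    },
}
assert process_response_format(country_format)["type"] == "object"

# Anything else is logged as unsupported and the helper returns None.
assert process_response_format({"type": "text"}) is None
```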
7 changes: 4 additions & 3 deletions litellm/llms/ollama/completion/transformation.py
@@ -22,7 +22,7 @@
ProviderField,
)

from ..common_utils import OllamaError, _convert_image
from ..common_utils import OllamaError, _convert_image, process_response_format

if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -172,8 +172,9 @@ def map_openai_params(
if param == "stop":
optional_params["stop"] = value
if param == "response_format" and isinstance(value, dict):
if value["type"] == "json_object":
optional_params["format"] = "json"
format = process_response_format(value)
if format is not None:
optional_params["format"] = format

return optional_params

7 changes: 5 additions & 2 deletions litellm/llms/ollama_chat.py
@@ -10,6 +10,7 @@
import litellm
from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.llms.ollama.common_utils import process_response_format
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
@@ -152,8 +153,10 @@ def map_openai_params(
optional_params["repeat_penalty"] = value
if param == "stop":
optional_params["stop"] = value
if param == "response_format" and value["type"] == "json_object":
optional_params["format"] = "json"
if param == "response_format" and isinstance(value, dict):
format = process_response_format(value)
if format is not None:
optional_params["format"] = format
### FUNCTION CALLING LOGIC ###
if param == "tools":
# ollama actually supports json output
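The same mapping applies to the `ollama_chat` provider. A sketch of how it would surface through `get_optional_params`, mirroring what the new test does for `ollama` (the model name and schema here are illustrative assumptions, not part of this diff):

```python
from litellm.utils import get_optional_params

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "Country",  # illustrative schema
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    },
}

optional_params = get_optional_params(
    custom_llm_provider="ollama_chat",
    model="llama3.1",  # illustrative model name
    response_format=response_format,
)

# With this change, "format" carries the JSON schema instead of just "json".
assert optional_params["format"] == response_format["json_schema"]["schema"]
```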
38 changes: 33 additions & 5 deletions tests/local_testing/test_ollama.py
@@ -1,13 +1,12 @@
import asyncio
import json
import os
import sys
import traceback

from dotenv import load_dotenv
from pydantic import BaseModel

load_dotenv()
import io
import os

sys.path.insert(
0, os.path.abspath("../..")
@@ -19,7 +18,12 @@
import litellm

## for ollama we can't test making the completion call
from litellm.utils import EmbeddingResponse, get_llm_provider, get_optional_params
from litellm.utils import (
    EmbeddingResponse,
    get_llm_provider,
    get_optional_params,
    type_to_response_format_param,
)


def test_get_ollama_params():
@@ -76,6 +80,31 @@ def test_ollama_json_mode():
# test_ollama_json_mode()


def test_ollama_structured_format():
    # assert that format receives a structured json schema
    class Country(BaseModel):
        name: str
        code: str
        languages: list[str]

    model_schema = type_to_response_format_param(Country) or {}
    model_schema = model_schema.get("json_schema", {}).get("schema", {})
    try:
        converted_params = get_optional_params(
            custom_llm_provider="ollama",
            model="llama2",
            temperature=0.5,
            response_format=Country,
        )
        assert converted_params == {
            "temperature": 0.5,
            "format": model_schema,
            "stream": False,
        }, f"{converted_params} != {{'temperature': 0.5, 'format': {model_schema}, 'stream': False}}"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")


@@ -135,7 +164,6 @@ def test_ollama_aembeddings(mock_aembeddings):

@pytest.mark.skip(reason="local only test")
def test_ollama_chat_function_calling():
import json

tools = [
{
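End to end, the feature would be exercised roughly like this (a sketch only: it assumes a locally running Ollama server and a pulled model that supports structured outputs, e.g. `llama3.1`; none of this is part of the diff):

```python
import litellm
from pydantic import BaseModel


class Country(BaseModel):
    name: str
    code: str
    languages: list[str]


# Requires a local Ollama server; litellm converts the Pydantic class to a
# JSON schema and, with this PR, forwards it as Ollama's "format" parameter.
response = litellm.completion(
    model="ollama_chat/llama3.1",  # illustrative model name
    messages=[{"role": "user", "content": "Tell me about Canada."}],
    response_format=Country,
)

country = Country.model_validate_json(response.choices[0].message.content)
print(country)
```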