Merge pull request #504 from PrefectHQ/response-model
jlowin authored Jul 24, 2023
2 parents a8bc03f + bf32051 commit 6bc5ce2
Showing 1 changed file with 74 additions and 5 deletions.
src/marvin/openai/ChatCompletion/__init__.py

@@ -1,15 +1,19 @@
 import functools
 
 import openai
-from pydantic import BaseSettings, Field
+from pydantic import BaseSettings, Field, BaseModel
 
 import marvin
+import warnings
+
+from typing import Type, Optional, Union, Literal
 
 
 class ChatCompletionConfig(BaseSettings):
     model: str = "gpt-3.5-turbo"
     temperature: float = 0
     functions: list = Field(default_factory=list)
+    function_call: Optional[Union[dict[Literal["name"], str], Literal["auto"]]] = None
     messages: list = Field(default_factory=list)
     api_key: str = Field(
         default_factory=lambda: (
@@ -29,6 +33,13 @@ def merge(self, *args, **kwargs):
         return {k: v for k, v in self.__dict__.items() if v != []}
 
 
+def process_list(lst):
+    if len(lst) == 1:
+        return lst[0]
+    else:
+        return lst
+
+
 class ChatCompletion(openai.ChatCompletion):
     def __new__(cls, *args, **kwargs):
         subclass = type(
@@ -39,16 +50,74 @@ def __new__(cls, *args, **kwargs):
         return subclass
 
     @classmethod
-    def create(cls, *args, **kwargs):
+    def create(cls, *args, response_model: Optional[Type[BaseModel]] = None, **kwargs):
         config = getattr(cls, "__config__", ChatCompletionConfig())
+        if response_model is not None:
+            if kwargs.get("functions"):
+                warnings.warn("Use of response_model with functions is not supported")
+            else:
+                kwargs["functions"] = [
+                    {
+                        "name": "format_response",
+                        "description": "Format the response",
+                        "parameters": response_model.schema(),
+                    }
+                ]
+                kwargs["function_call"] = {
+                    "name": "format_response",
+                }
         payload = config.merge(**kwargs)
-        return cls.observer(super(ChatCompletion, cls).create)(*args, **payload)
+        response = cls.observer(super(ChatCompletion, cls).create)(*args, **payload)
+        response.to_model = lambda: (
+            process_list(
+                list(
+                    map(
+                        lambda x: response_model.parse_raw(
+                            x.message.function_call.arguments
+                        ),
+                        response.choices,
+                    )
+                )
+            )
+        )
+        return response
 
     @classmethod
-    async def acreate(cls, *args, **kwargs):
+    async def acreate(
+        cls, *args, response_model: Optional[Type[BaseModel]] = None, **kwargs
+    ):
         config = getattr(cls, "__config__", ChatCompletionConfig())
+        if response_model is not None:
+            if kwargs.get("functions"):
+                warnings.warn("Use of response_model with functions is not supported")
+            else:
+                kwargs["functions"] = [
+                    {
+                        "name": "format_response",
+                        "description": "Format the response",
+                        "parameters": response_model.schema(),
+                    }
+                ]
+                kwargs["function_call"] = {
+                    "name": "format_response",
+                }
         payload = config.merge(**kwargs)
-        return await cls.observer(super(ChatCompletion, cls).acreate)(*args, **payload)
+        response = await cls.observer(super(ChatCompletion, cls).acreate)(
+            *args, **payload
+        )
+        response.to_model = lambda: (
+            process_list(
+                list(
+                    map(
+                        lambda x: response_model.parse_raw(
+                            x.message.function_call.arguments
+                        ),
+                        response.choices,
+                    )
+                )
+            )
+        )
+        return response
 
     @staticmethod
     def observer(func):
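For reference, a minimal sketch of how the new response_model argument is meant to be used. The Fruit model, the prompt, and the import path are illustrative assumptions, not part of this commit:

from pydantic import BaseModel

# Import path assumed from the file location within the repo.
from marvin.openai.ChatCompletion import ChatCompletion


class Fruit(BaseModel):
    # Hypothetical response schema -- any pydantic model should work.
    name: str
    color: str


# create() injects a single "format_response" function built from
# Fruit.schema() and pins the model to it via function_call.
response = ChatCompletion.create(
    messages=[{"role": "user", "content": "Name a red fruit."}],
    response_model=Fruit,
)

# to_model() parses each choice's function_call.arguments back into
# Fruit; with a single choice, process_list unwraps the one-element
# list and returns the instance directly.
fruit = response.to_model()
print(fruit.name, fruit.color)

Note that if the caller already passed functions, the new code warns and leaves them untouched rather than overwriting them with the response_model schema.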
