Skip to content

Commit

Permalink
Merge pull request #42 from minimaxir:pydantic2.0
Browse files Browse the repository at this point in the history
Migrate to Pydantic 2.0
  • Loading branch information
minimaxir authored Jul 3, 2023
2 parents 7cedbed + bcd6624 commit 26f0036
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 35 deletions.
2 changes: 1 addition & 1 deletion examples/notebooks/schema_ttrpg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@
"metadata": {},
"outputs": [],
"source": [
"input_ttrpg = write_ttrpg_setting.parse_obj(response_structured)"
"input_ttrpg = write_ttrpg_setting.model_validate(response_structured)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="simpleaichat",
packages=["simpleaichat"], # this must be the same as the name above
version="0.2.0",
version="0.2.1",
description="A Python package for easily interfacing with chat apps, with robust features and minimal code complexity.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
Expand All @@ -18,7 +18,7 @@
},
python_requires=">=3.8",
install_requires=[
"pydantic>=1.10",
"pydantic>=2.0",
"fire>=0.3.0",
"httpx>=0.24.1",
"python-dotenv>=1.0.0",
Expand Down
17 changes: 13 additions & 4 deletions simpleaichat/chatgpt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pydantic import HttpUrl
from pydantic import HttpUrl, ConfigDict
from httpx import Client, AsyncClient
from typing import List, Dict, Union, Set, Any
import orjson

from .models import ChatMessage, ChatSession
from .utils import remove_a_key

tool_prompt = """From the list of tools below:
- Reply ONLY with the number of the tool appropriate in response to the user's last message.
Expand All @@ -17,6 +18,7 @@ class ChatGPTSession(ChatSession):
input_fields: Set[str] = {"role", "content", "name"}
system: str = "You are a helpful assistant."
params: Dict[str, Any] = {"temperature": 0.7}
model_config: ConfigDict(arbitrary_types_allowed=True)

def prepare_request(
self,
Expand All @@ -41,7 +43,9 @@ def prepare_request(
prompt, input_schema
), f"prompt must be an instance of {input_schema.__name__}"
user_message = ChatMessage(
role="function", content=prompt.json(), name=input_schema.__name__
role="function",
content=prompt.model_dump_json(),
name=input_schema.__name__,
)

gen_params = params or self.params
Expand All @@ -60,7 +64,9 @@ def prepare_request(
functions.append(input_function)
if output_schema:
output_function = self.schema_to_function(output_schema)
functions.append(output_function) if output_function not in functions else None
functions.append(
output_function
) if output_function not in functions else None
if is_function_calling_required:
data["function_call"] = {"name": output_schema.__name__}
data["functions"] = functions
Expand All @@ -69,10 +75,13 @@ def prepare_request(

def schema_to_function(self, schema: Any):
assert schema.__doc__, f"{schema.__name__} is missing a docstring."
schema_dict = schema.model_json_schema()
remove_a_key(schema_dict, "title")

return {
"name": schema.__name__,
"description": schema.__doc__,
"parameters": schema.schema(),
"parameters": schema_dict,
}

def gen(
Expand Down
28 changes: 10 additions & 18 deletions simpleaichat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,16 @@ def now_tz():
class ChatMessage(BaseModel):
role: str
content: str
name: Optional[str]
function_call: Optional[str]
name: Optional[str] = None
function_call: Optional[str] = None
received_at: datetime.datetime = Field(default_factory=now_tz)
finish_reason: Optional[str]
prompt_length: Optional[int]
completion_length: Optional[int]
total_length: Optional[int]

class Config:
json_loads = orjson.loads
json_dumps = orjson_dumps
finish_reason: Optional[str] = None
prompt_length: Optional[int] = None
completion_length: Optional[int] = None
total_length: Optional[int] = None

def __str__(self) -> str:
return str(self.dict(exclude_none=True))
return str(self.model_dump(exclude_none=True))


class ChatSession(BaseModel):
Expand All @@ -53,10 +49,6 @@ class ChatSession(BaseModel):
total_length: int = 0
title: Optional[str] = None

class Config:
json_loads = orjson.loads
json_dumps = orjson_dumps

def __str__(self) -> str:
sess_start_str = self.created_at.strftime("%Y-%m-%d %H:%M:%S")
last_message_str = self.messages[-1].received_at.strftime("%Y-%m-%d %H:%M:%S")
Expand All @@ -73,12 +65,12 @@ def format_input_messages(
else self.messages
)
return (
[system_message.dict(include=self.input_fields, exclude_none=True)]
[system_message.model_dump(include=self.input_fields, exclude_none=True)]
+ [
m.dict(include=self.input_fields, exclude_none=True)
m.model_dump(include=self.input_fields, exclude_none=True)
for m in recent_messages
]
+ [user_message.dict(include=self.input_fields, exclude_none=True)]
+ [user_message.model_dump(include=self.input_fields, exclude_none=True)]
)

def add_messages(
Expand Down
16 changes: 6 additions & 10 deletions simpleaichat/simpleaichat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,25 @@
from contextlib import contextmanager, asynccontextmanager
import csv

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from httpx import Client, AsyncClient
from typing import List, Dict, Union, Optional, Any
import orjson
from dotenv import load_dotenv
from rich.console import Console

from .utils import wikipedia_search_lookup
from .models import ChatMessage, ChatSession, orjson_dumps
from .models import ChatMessage, ChatSession
from .chatgpt import ChatGPTSession

load_dotenv()


class AIChat(BaseModel):
client: Union[Client, AsyncClient]
client: Any
default_session: Optional[ChatSession]
sessions: Dict[Union[str, UUID], ChatSession] = {}

class Config:
arbitrary_types_allowed = True
json_loads = orjson.loads
json_dumps = orjson_dumps
model_config: ConfigDict(arbitrary_types_allowed=True)

def __init__(
self,
Expand Down Expand Up @@ -222,7 +218,7 @@ def interactive_console(self, character: str = None, prime: bool = True) -> None

def __str__(self) -> str:
if self.default_session:
return self.default_session.json(
return self.default_session.model_dump_json(
exclude={"api_key", "api_url"},
exclude_none=True,
option=orjson.OPT_INDENT_2,
Expand All @@ -240,7 +236,7 @@ def save_session(
minify: bool = False,
):
sess = self.get_session(id)
sess_dict = sess.dict(
sess_dict = sess.model_dump(
exclude={"auth", "api_url", "input_fields"},
exclude_none=True,
)
Expand Down
10 changes: 10 additions & 0 deletions simpleaichat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,13 @@ async def wikipedia_search_lookup_async(query: str, sentences: int = 1) -> str:

def fd(description: str):
return Field(description=description)


# https://stackoverflow.com/a/58938747
def remove_a_key(d, remove_key):
if isinstance(d, dict):
for key in list(d.keys()):
if key == remove_key:
del d[key]
else:
remove_a_key(d[key], remove_key)

0 comments on commit 26f0036

Please sign in to comment.