Skip to content

Commit

Permalink
chore: Update API key handling and introduce ez methods for images li…
Browse files Browse the repository at this point in the history
…ke chat & mod
  • Loading branch information
herumes committed Jun 10, 2024
1 parent 3f15102 commit 709279a
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 5 deletions.
8 changes: 8 additions & 0 deletions shuttleai/client/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def __init__(
):
super().__init__(base_url, api_key, timeout)

if self.api_key is None:
raise ShuttleAIException(
"API key not provided. Please set SHUTTLEAI_API_KEY environment variable.\n"
+ "Alternatively, you may pass it as an argument to the ShuttleAI class constructor as `api_key`.\n"
+ "In addition to that, you can also set the api key after creating the ShuttleAI object by setting \
`client.api_key`.\n"
)

self._timeout = timeout if isinstance(timeout, ClientTimeout) else ClientTimeout(total=timeout)
if default_headers:
self.default_headers = default_headers
Expand Down
11 changes: 10 additions & 1 deletion shuttleai/client/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,22 @@ def __init__(
http_client: Optional[Client] = None,
):
super().__init__(base_url, api_key, timeout)

if self.api_key is None:
raise ShuttleAIException(
"API key not provided. Please set SHUTTLEAI_API_KEY environment variable.\n"
+ "Alternatively, you may pass it as an argument to the ShuttleAI class constructor as `api_key`.\n"
+ "In addition to that, you can also set the api key after creating the ShuttleAI object by setting \
`client.api_key`.\n"
)

if default_headers:
self.default_headers = default_headers

if http_client:
self._http_client = http_client
else:
self._http_client = Client(follow_redirects=True, timeout=timeout)
self._http_client = Client(timeout=timeout)

self.chat: resources.Chat = resources.Chat(self)
self.images: resources.Images = resources.Images(self)
Expand Down
2 changes: 0 additions & 2 deletions shuttleai/resources/etc/insults.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ async def generate(
self,
model: Optional[Union[str, Literal["insult-1"]]] = "insult-1",
) -> InsultResponse:

return await self.handle_request( # type: ignore
method="get",
endpoint=f"v1/insults?key={self._client.api_key}&model={model}",
Expand All @@ -23,7 +22,6 @@ def generate(
self,
model: Optional[Union[str, Literal["insult-1"]]] = "insult-1",
) -> InsultResponse:

return self.handle_request( # type: ignore
method="get",
endpoint=f"v1/insults?key={self._client.api_key}&model={model}",
Expand Down
2 changes: 0 additions & 2 deletions shuttleai/resources/etc/jokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ async def generate(
self,
model: Optional[Union[str, Literal["joke-1"]]] = "joke-1",
) -> JokeResponse:

return await self.handle_request( # type: ignore
method="get",
endpoint=f"v1/jokes?key={self._client.api_key}&model={model}",
Expand All @@ -23,7 +22,6 @@ def generate(
self,
model: Optional[Union[str, Literal["joke-1"]]] = "joke-1",
) -> JokeResponse:

return self.handle_request( # type: ignore
method="get",
endpoint=f"v1/jokes?key={self._client.api_key}&model={model}",
Expand Down
39 changes: 39 additions & 0 deletions shuttleai/schemas/images/generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,49 @@ class Image(BaseModel):
url: str
"""The URL of the image."""

def to_file(self, path: str) -> None:
"""Save the image to a file.
Args:
path (str): The path to save the image to.
"""
import httpx

response = httpx.get(self.url)
with open(path, "wb") as file:
file.write(response.content)

def to_bytes(self) -> bytes:
"""Get the image as bytes.
Returns:
bytes: The image as bytes.
"""
import httpx

response = httpx.get(self.url)
return response.content

def show(self) -> None:
"""Show the image using pillow."""
from io import BytesIO

from PIL import Image as PILImage

image = PILImage.open(BytesIO(self.to_bytes()))
image.show()

def __str__(self) -> str:
return self.url


class ImagesGenerationResponse(BaseModel):
created: int
"""The Unix timestamp when the image generation was created."""

data: List[Image]
"""The generated image(s)."""

@property
def first_image(self) -> Image:
return self.data[0]

0 comments on commit 709279a

Please sign in to comment.