-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bug] Add Azure-hosted Dalle Image Generation Support #2586
Changes from 3 commits
4afab28
f5b6fd8
d3fd3e2
3856d39
db37900
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
import re | ||
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union | ||
|
||
from openai import OpenAI | ||
from openai import AzureOpenAI, OpenAI | ||
from PIL.Image import Image | ||
|
||
from autogen import Agent, ConversableAgent, code_utils | ||
|
@@ -18,13 +18,16 @@ | |
EXAMPLE: Blue background, 3D shapes, ... | ||
""" | ||
|
||
VALID_DALLE_MODELS = ["dall-e-2", "dall-e-3"] | ||
|
||
|
||
class ImageGenerator(Protocol): | ||
"""This class defines an interface for image generators. | ||
|
||
Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as | ||
input and returns a PIL Image object. | ||
|
||
NOTE: Only OpenAI's and Azure's DALL-E models are supported. | ||
NOTE: Current implementation does not allow you to edit a previously existing image. | ||
""" | ||
|
||
|
@@ -80,14 +83,17 @@ def __init__( | |
num_images (int): The number of images to generate. | ||
""" | ||
config_list = llm_config["config_list"] | ||
_validate_dalle_model(config_list[0]["model"]) | ||
dalle_configs = _find_valid_dalle_config(config_list) | ||
assert len(dalle_configs) > 0, "Invalid DALL-E config. Must contain a valid DALL-E model." | ||
|
||
_validate_dalle_model(dalle_configs[0]["model"]) | ||
_validate_resolution_format(resolution) | ||
|
||
self._model = config_list[0]["model"] | ||
self._model = dalle_configs[0]["model"] | ||
self._resolution = resolution | ||
self._quality = quality | ||
self._num_images = num_images | ||
self._dalle_client = OpenAI(api_key=config_list[0]["api_key"]) | ||
self._dalle_client = self._dalle_client_factory(dalle_configs[0]) | ||
|
||
def generate_image(self, prompt: str) -> Image: | ||
response = self._dalle_client.images.generate( | ||
|
@@ -108,6 +114,12 @@ def cache_key(self, prompt: str) -> str: | |
keys = (prompt, self._model, self._resolution, self._quality, self._num_images) | ||
return ",".join([str(k) for k in keys]) | ||
|
||
def _dalle_client_factory(self, dalle_config: Dict) -> Union[OpenAI, AzureOpenAI]:
    """Create the client described by a single DALL-E config entry.

    Args:
        dalle_config (Dict): One entry of the LLM config list. `api_key` is required. When
            `api_type` is "azure", `api_version` and the endpoint (`azure_endpoint`, or
            `base_url` as a fallback) are forwarded to the client if they are present.

    Returns:
        Union[OpenAI, AzureOpenAI]: An `AzureOpenAI` client when `api_type` is "azure",
            otherwise an `OpenAI` client.
    """
    if dalle_config.get("api_type") != "azure":
        return OpenAI(api_key=dalle_config["api_key"])

    # AzureOpenAI cannot be built from an API key alone: it also needs the resource endpoint
    # and the API version. Only configured values are forwarded, so anything left out is
    # resolved by the client itself (NOTE(review): presumably via its environment-variable
    # fallbacks -- confirm against the installed openai version).
    azure_kwargs = {"api_key": dalle_config["api_key"]}
    if dalle_config.get("api_version") is not None:
        azure_kwargs["api_version"] = dalle_config["api_version"]
    endpoint = dalle_config.get("azure_endpoint", dalle_config.get("base_url"))
    if endpoint is not None:
        azure_kwargs["azure_endpoint"] = endpoint
    return AzureOpenAI(**azure_kwargs)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For OpenAI, I'm not sure if more keys are possible. Why do we pass the |
||
|
||
|
||
class ImageGeneration(AgentCapability): | ||
"""This capability allows a ConversableAgent to generate images based on the message received from other Agents. | ||
|
@@ -287,5 +299,9 @@ def _validate_resolution_format(resolution: str): | |
|
||
|
||
def _validate_dalle_model(model: str): | ||
if model not in ["dall-e-3", "dall-e-2"]: | ||
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'") | ||
if model not in VALID_DALLE_MODELS: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For Azure models, they are not necessarily named as one of VALID_DALLE_MODELS. They can be named as anything. |
||
raise ValueError(f"Invalid DALL-E model: {model}. Must be in {VALID_DALLE_MODELS}") | ||
|
||
|
||
def _find_valid_dalle_config(config_list: List[Dict]) -> List[Dict]:
    """Return the configs whose `model` is one of `VALID_DALLE_MODELS`, in their original order.

    Args:
        config_list (List[Dict]): Candidate LLM config entries.

    Returns:
        List[Dict]: The matching entries; empty when none match.
    """
    # `.get` lets entries without a "model" key be skipped instead of aborting the scan with KeyError.
    return [config for config in config_list if config.get("model") in VALID_DALLE_MODELS]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AzureOpenAI requires more fields than api_key.