Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Add Azure-hosted Dalle Image Generation Support #2586

Closed
wants to merge 5 commits into from
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions autogen/agentchat/contrib/capabilities/generate_images.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union

from openai import OpenAI
from openai import AzureOpenAI, OpenAI
from PIL.Image import Image

from autogen import Agent, ConversableAgent, code_utils
Expand All @@ -18,13 +18,16 @@
EXAMPLE: Blue background, 3D shapes, ...
"""

VALID_DALLE_MODELS = ["dall-e-2", "dall-e-3"]


class ImageGenerator(Protocol):
"""This class defines an interface for image generators.

Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as
input and returns a PIL Image object.

NOTE: Only OpenAI's and Azure's DALL-E model are supported.
NOTE: Current implementation does not allow you to edit a previously existing image.
"""

Expand Down Expand Up @@ -80,14 +83,17 @@ def __init__(
num_images (int): The number of images to generate.
"""
config_list = llm_config["config_list"]
_validate_dalle_model(config_list[0]["model"])
dalle_configs = _find_valid_dalle_config(config_list)
assert len(dalle_configs) > 0, "Invalid DALL-E config. Must contain a valid DALL-E model."

_validate_dalle_model(dalle_configs[0]["model"])
_validate_resolution_format(resolution)

self._model = config_list[0]["model"]
self._model = dalle_configs[0]["model"]
self._resolution = resolution
self._quality = quality
self._num_images = num_images
self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])
self._dalle_client = self._dalle_client_factory(dalle_configs[0])

def generate_image(self, prompt: str) -> Image:
response = self._dalle_client.images.generate(
Expand All @@ -108,6 +114,12 @@ def cache_key(self, prompt: str) -> str:
keys = (prompt, self._model, self._resolution, self._quality, self._num_images)
return ",".join([str(k) for k in keys])

def _dalle_client_factory(self, dalle_config: Dict) -> Union[OpenAI, AzureOpenAI]:
if dalle_config.get("api_type") == "azure":
return AzureOpenAI(api_key=dalle_config["api_key"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AzureOpenAI requires more fields than api_key.

else:
return OpenAI(api_key=dalle_config["api_key"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For OpenAI, I'm not sure if more keys are possible. Why do we pass the api_key only here?



class ImageGeneration(AgentCapability):
"""This capability allows a ConversableAgent to generate images based on the message received from other Agents.
Expand Down Expand Up @@ -287,5 +299,9 @@ def _validate_resolution_format(resolution: str):


def _validate_dalle_model(model: str):
if model not in ["dall-e-3", "dall-e-2"]:
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")
if model not in VALID_DALLE_MODELS:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Azure models, they are not necessarily named as one of VALID_DALLE_MODELS. They can be named as anything.

raise ValueError(f"Invalid DALL-E model: {model}. Must be in {VALID_DALLE_MODELS}")


def _find_valid_dalle_config(config_list: List[Dict]) -> List[Dict]:
return list(filter(lambda config: config["model"] in VALID_DALLE_MODELS, config_list))
Loading