Skip to content

Commit

Permalink
Huggingface retry generate_image with delay (Significant-Gravitas#2745)
Browse files Browse the repository at this point in the history
Co-authored-by: Media <[email protected]>
Co-authored-by: Nicholas Tindle <[email protected]>
Co-authored-by: Nicholas Tindle <[email protected]>
Co-authored-by: k-boikov <[email protected]>
Co-authored-by: merwanehamadi <[email protected]>
Co-authored-by; lc0rp
  • Loading branch information
primaryobjects authored May 16, 2023
1 parent c1cd54d commit f424fac
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 17 deletions.
50 changes: 37 additions & 13 deletions autogpt/commands/image_gen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
""" Image Generation Module for AutoGPT."""
import io
import json
import time
import uuid
from base64 import b64decode

Expand Down Expand Up @@ -61,20 +63,42 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
"X-Use-Cache": "false",
}

response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)

image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")

image.save(filename)
retry_count = 0
while retry_count < 10:
response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)

return f"Saved to disk:{filename}"
if response.ok:
try:
image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")
image.save(filename)
return f"Saved to disk:{filename}"
except Exception as e:
logger.error(e)
break
else:
try:
error = json.loads(response.text)
if "estimated_time" in error:
delay = error["estimated_time"]
logger.debug(response.text)
logger.info("Retrying in", delay)
time.sleep(delay)
else:
break
except Exception as e:
logger.error(e)
break

retry_count += 1

return f"Error creating image."


def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str:
Expand Down
112 changes: 108 additions & 4 deletions tests/test_image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from PIL import Image

from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
from autogpt.config import Config
from tests.utils import requires_api_key


Expand All @@ -19,7 +20,7 @@ def image_size(request):
reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution."
)
@requires_api_key("OPENAI_API_KEY")
def test_dalle(config, workspace, image_size, patched_api_requestor):
def test_dalle(config, workspace, image_size):
"""Test DALL-E image generation."""
generate_and_validate(
config,
Expand Down Expand Up @@ -48,18 +49,18 @@ def test_huggingface(config, workspace, image_size, image_model):
)


@pytest.mark.skip(reason="External SD WebUI may not be available.")
@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui(config, workspace, image_size):
"""Test SD WebUI image generation."""
generate_and_validate(
config,
workspace,
image_provider="sdwebui",
image_provider="sd_webui",
image_size=image_size,
)


@pytest.mark.skip(reason="External SD WebUI may not be available.")
@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui_negative_prompt(config, workspace, image_size):
gen_image = functools.partial(
generate_image_with_sd_webui,
Expand Down Expand Up @@ -103,3 +104,106 @@ def generate_and_validate(
assert image_path.exists()
with Image.open(image_path) as img:
assert img.size == (image_size, image_size)


def test_huggingface_fail_request_with_delay(mocker):
config = Config()
config.huggingface_api_token = "1"

# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading","estimated_time":0}'

# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")

config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

result = generate_image("astronaut riding a horse", 512)

assert result == "Error creating image."

# Verify retry was called with delay.
mock_sleep.assert_called_with(0)


def test_huggingface_fail_request_no_delay(mocker):
config = Config()
config.huggingface_api_token = "1"

# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = (
'{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading"}'
)

# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")

config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

result = generate_image("astronaut riding a horse", 512)

assert result == "Error creating image."

# Verify retry was not called.
mock_sleep.assert_not_called()


def test_huggingface_fail_request_bad_json(mocker):
config = Config()
config.huggingface_api_token = "1"

# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = '{"error:}'

# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")

config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

result = generate_image("astronaut riding a horse", 512)

assert result == "Error creating image."

# Verify retry was not called.
mock_sleep.assert_not_called()


def test_huggingface_fail_request_bad_image(mocker):
config = Config()
config.huggingface_api_token = "1"

# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 200

config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

result = generate_image("astronaut riding a horse", 512)

assert result == "Error creating image."


def test_huggingface_fail_missing_api_token(mocker):
config = Config()
config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

# Mock requests.post to raise ValueError
mock_post = mocker.patch("requests.post", side_effect=ValueError)

# Verify request raises an error.
with pytest.raises(ValueError):
generate_image("astronaut riding a horse", 512)

0 comments on commit f424fac

Please sign in to comment.