-
Notifications
You must be signed in to change notification settings - Fork 47
/
main.py
68 lines (53 loc) · 1.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import io
import os
import tempfile

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

from ml import obtain_image
# Application instance; the @app.get / @app.post decorators in this module
# register their handlers on it.
app = FastAPI()
@app.get("/")
def read_root():
    """Return a fixed greeting payload for the root path."""
    greeting = {"Hello": "World"}
    return greeting
@app.get("/items/{item_id}")
def read_item(item_id: int):
    """Echo the ``item_id`` path parameter back to the caller.

    Args:
        item_id: Value taken from the ``{item_id}`` segment of the URL,
            declared as ``int``.

    Returns:
        A single-key mapping ``{"item_id": item_id}``.
    """
    return dict(item_id=item_id)
class Item(BaseModel):
    """Model describing an item, used as the body type of ``POST /items/``."""

    # Required fields: neither declares a default.
    name: str
    price: float

    # Optional; defaults to an empty list when the field is not supplied.
    tags: list[str] = []
@app.post("/items/")
def create_item(item: Item):
    """Accept an ``Item`` payload and hand it straight back.

    Args:
        item: The submitted item, typed as ``Item``.

    Returns:
        The same ``item`` object that was received, unmodified.
    """
    return item
@app.get("/generate")
def generate_image(
    prompt: str,
    *,
    seed: int | None = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
):
    """Generate an image for ``prompt`` and return it as a PNG file response.

    The image is written to a temporary file that is unique to this request
    and removed once the response has been sent. A single shared path would
    let concurrent requests overwrite each other's output while it is still
    being streamed to a client.

    Args:
        prompt: Text passed to ``obtain_image`` as its positional argument.
        seed: Forwarded to ``obtain_image``; ``None`` when not supplied.
        num_inference_steps: Forwarded to ``obtain_image``.
        guidance_scale: Forwarded to ``obtain_image``.

    Returns:
        A ``FileResponse`` serving the generated PNG.
    """
    image = obtain_image(
        prompt,
        num_inference_steps=num_inference_steps,
        seed=seed,
        guidance_scale=guidance_scale,
    )

    # mkstemp gives a unique path per request; the ".png" suffix keeps the
    # extension the original code relied on when saving and serving the file.
    fd, image_path = tempfile.mkstemp(suffix=".png")
    # Only the path is needed; image.save() opens the file itself.
    os.close(fd)
    try:
        image.save(image_path)
    except Exception:
        # Do not leave a partially written file behind if saving fails.
        os.unlink(image_path)
        raise

    # Delete the temporary file after the response body has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.unlink, image_path)
    return FileResponse(image_path, media_type="image/png", background=cleanup)
@app.get("/generate-memory")
def generate_image_memory(
    prompt: str,
    *,
    seed: int | None = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
):
    """Generate an image for ``prompt`` and stream it back as PNG from memory.

    The image is serialised into an in-memory buffer instead of a file on
    disk, and that buffer is handed to a ``StreamingResponse``.

    Args:
        prompt: Text passed to ``obtain_image`` as its positional argument.
        seed: Forwarded to ``obtain_image``; ``None`` when not supplied.
        num_inference_steps: Forwarded to ``obtain_image``.
        guidance_scale: Forwarded to ``obtain_image``.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.
    """
    render_options = {
        "num_inference_steps": num_inference_steps,
        "seed": seed,
        "guidance_scale": guidance_scale,
    }
    rendered = obtain_image(prompt, **render_options)

    png_buffer = io.BytesIO()
    rendered.save(png_buffer, format="PNG")
    # Rewind so the response reads from the first byte rather than from the
    # end-of-write position left by save().
    png_buffer.seek(0)
    return StreamingResponse(png_buffer, media_type="image/png")