-
Notifications
You must be signed in to change notification settings - Fork 3
/
validation.py
204 lines (166 loc) · 5.72 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from __future__ import annotations
import jax
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import numpy as np
from PIL import Image
from diffusers import (
FlaxAutoencoderKL,
FlaxDPMSolverMultistepScheduler,
FlaxUNet2DConditionModel,
)
from transformers import ByT5Tokenizer
from architecture import setup_model
# TODO: try half-precision
# Fixed token length that every prompt is padded or truncated to.
tokenized_prompt_max_length = 1024


def tokenize_prompts(prompt: list[str]):
    """Tokenize prompts with ByT5 into a fixed-length batch of token ids.

    Args:
        prompt: The prompt strings to encode.

    Returns:
        The ``input_ids`` of the tokenizer output as a JAX array, with every
        prompt padded/truncated to ``tokenized_prompt_max_length`` tokens.
    """
    tokenizer = ByT5Tokenizer()
    encoding = tokenizer(
        text=prompt,
        max_length=tokenized_prompt_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="jax",
    )
    return encoding["input_ids"]
def convert_images(images: jnp.ndarray):
    """Convert a batch of image arrays into a list of RGB PIL images.

    Each element along the first axis of ``images`` is converted to a NumPy
    array and wrapped in a ``PIL.Image`` with mode ``"RGB"``.

    Args:
        images: Images indexed along the first axis. Each element is
            presumably a ``(height, width, 3)`` ``uint8`` array, which is what
            the validation sampler in this module returns per image. Confirm
            this for other callers.

    Returns:
        A list with one ``PIL.Image.Image`` per element of ``images``.
    """
    host_arrays = (np.asarray(single) for single in images)
    return [Image.fromarray(pixels, mode="RGB") for pixels in host_arrays]
def get_validation_predictions_lambda(
    vae: FlaxAutoencoderKL,
    vae_params,
    unet: FlaxUNet2DConditionModel,
    *,
    num_inference_steps: int = 20,
    image_height: int = 256,
    image_width: int = 256,
):
    """Build a sampler that turns encoded prompts into uint8 validation images.

    The returned function has the signature
    ``(seed, unet_params, encoded_prompts) -> images``. It is meant to be
    wrapped in ``jax.pmap``, so ``seed`` and ``encoded_prompts`` are the
    per-device values.

    Args:
        vae: Flax VAE whose ``decode`` method maps latents to images.
        vae_params: Parameters for ``vae``.
        unet: Flax UNet that predicts the denoising target at each step.
        num_inference_steps: Number of DPM-Solver steps (default 20).
        image_height: Target image height in pixels. It should presumably be a
            multiple of the VAE scale factor.
        image_width: Target image width in pixels. It should presumably be a
            multiple of the VAE scale factor.

    Returns:
        A function returning a ``(batch, height, width, channel)`` ``uint8``
        array, where ``batch == encoded_prompts.shape[0]``.
    """
    scheduler = FlaxDPMSolverMultistepScheduler.from_config(
        config={
            "_diffusers_version": "0.16.0",
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "set_alpha_to_one": False,
            "skip_prk_steps": True,
            "steps_offset": 1,
            "trained_betas": None,
        }
    )

    # Spatial downsampling factor between pixel space and latent space.
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    latent_height = image_height // vae_scale_factor
    latent_width = image_width // vae_scale_factor

    def __predict_images(seed, unet_params, encoded_prompts):
        # One latent per prompt, laid out (batch, channels, height, width).
        # The batch must equal the prompt batch so that encoder_hidden_states
        # lines up with the latents. Under pmap this is the per-device shard
        # size, which is static at trace time.
        latent_shape = (
            encoded_prompts.shape[0],
            unet.in_channels,
            latent_height,
            latent_width,
        )

        def ___timestep(step, step_args):
            latents, scheduler_state = step_args
            t = jnp.asarray(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.array(
                jnp.broadcast_to(t, latents.shape[0]), dtype=jnp.int32
            )
            scaled_latent_input = jnp.array(
                scheduler.scale_model_input(scheduler_state, latents, t)
            )
            # Predict the model target. It is a v-prediction, per the
            # scheduler config above.
            unet_prediction_sample = unet.apply(
                {"params": unet_params},
                sample=scaled_latent_input,
                timesteps=timestep,
                encoder_hidden_states=encoded_prompts,
                train=False,
            ).sample
            # Compute the previous noisy sample x_t -> x_t-1. The returned
            # tuple must match the fori_loop carry (latents, scheduler_state).
            return scheduler.step(
                scheduler_state, unet_prediction_sample, t, latents
            ).to_tuple()

        # Initialize the scheduler state for the requested number of steps.
        # `shape` is the latent shape, presumably used to size the solver's
        # internal buffers.
        initial_scheduler_state = scheduler.set_timesteps(
            scheduler.create_state(),
            num_inference_steps=num_inference_steps,
            shape=latent_shape,
        )

        # Start from Gaussian noise scaled by the scheduler's initial sigma.
        initial_latents = (
            jax.random.normal(
                jax.random.PRNGKey(seed), shape=latent_shape, dtype=jnp.float32
            )
            * initial_scheduler_state.init_noise_sigma
        )

        # Run the denoising loop.
        final_latents, _ = jax.lax.fori_loop(
            0,
            num_inference_steps,
            ___timestep,
            (initial_latents, initial_scheduler_state),
        )

        # Undo the VAE latent scaling, then decode latents to pixels.
        scaled_final_latents = 1 / vae.config.scaling_factor * final_latents
        vae_output = vae.apply(
            {"params": vae_params},
            latents=scaled_final_latents,
            deterministic=True,
            method=vae.decode,
        ).sample

        # Map the decoder output from [-1, 1] to [0, 1] (presumably the
        # decoder's range, per the usual diffusers convention). Then move
        # channels last and quantize to 8-bit:
        # (batch, channel, height, width) => (batch, height, width, channel).
        return (
            ((vae_output / 2 + 0.5).transpose(0, 2, 3, 1).clip(0, 1) * 255)
            .round()
            .astype(jnp.uint8)
        )

    return __predict_images
if __name__ == "__main__":
    # Pretrained/frozen and training model setup
    text_encoder, text_encoder_params, vae, vae_params, unet, unet_params = setup_model(
        43,  # seed
        None,  # dtype (defaults to float32)
        True,  # load pre-trained
        "character-aware-diffusion/charred",
        None,
    )

    # Validation prompts, as English/French pairs.
    validation_prompts = [
        "a white car",
        "une voiture blanche",
        "a running shoe",
        "une chaussure de course",
        "a perfumer and his perfume organ",
        "un parfumeur et son orgue à parfums",
        "two people",
        "deux personnes",
        "a happy cartoon cat",
        "un dessin de chat heureux",
        "a city skyline",
        "un panorama urbain",
        "a Marilyn Monroe portrait",
        "un portrait de Marilyn Monroe",
        "a rainy day in London",
        "Londres sous la pluie",
    ]

    tokenized_prompts = tokenize_prompts(validation_prompts)
    encoded_prompts = text_encoder(
        tokenized_prompts,
        params=text_encoder_params,
        train=False,
    )[0]

    validation_predictions_lambda = get_validation_predictions_lambda(
        vae,
        vae_params,
        unet,
    )

    # Run one copy of the sampler per local device. shard() splits the prompt
    # batch across devices, so len(validation_prompts) must be divisible by
    # the local device count.
    get_validation_predictions = jax.pmap(
        fun=validation_predictions_lambda,
        axis_name="encoded_prompts",
        donate_argnums=(),
    )

    # Every device receives the same seed (2) and its own shard of prompts.
    image_predictions = get_validation_predictions(
        replicate(2), replicate(unet_params), shard(encoded_prompts)
    )
    # pmap stacks the per-device results, giving
    # (devices, per_device_batch, height, width, channel). Merge the leading
    # axes so convert_images receives one image per element.
    image_predictions = image_predictions.reshape(
        -1, *image_predictions.shape[-3:]
    )
    images = convert_images(image_predictions)