-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimize_text_emb.py
294 lines (221 loc) · 10.9 KB
/
optimize_text_emb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import argparse
import json
import os
from datetime import datetime
from typing import Tuple
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, EMAModel
from evaluation import generate_samples_and_evaluate_blip_vqa
def parse_args():
    """Parse and validate the command-line arguments of this script.

    Returns:
        argparse.Namespace with ``num_chunks``, ``chunk_idx``,
        ``generator_directory_names``, ``stable_diffusion_checkpoint`` and
        ``compbench_category_name``.

    Raises:
        ValueError: if ``--chunk_idx`` is not in the half-open range
            ``[0, --num_chunks)``.
    """
    parser = argparse.ArgumentParser(description="Optimize text embedding for better compositionality")
    parser.add_argument(
        "--num_chunks",
        type=int,
        default=20,
        help="Number of chunks the sorted prompt list is split into.",
    )
    parser.add_argument(
        "--chunk_idx",
        type=int,
        required=True,
        help="Index of the prompt chunk processed by this run, in [0, --num_chunks).",
    )
    parser.add_argument(
        "--generator_directory_names",
        type=str,
        nargs='+',
        default=["syngen", "deepfloyd", "sd-v1-4"],
        help="Per-prompt sub-directories (one per generator) whose samples and VQA scores are used for training.",
    )
    parser.add_argument(
        "--stable_diffusion_checkpoint",
        type=str,
        default="CompVis/stable-diffusion-v1-4",
        choices=["CompVis/stable-diffusion-v1-4"],
        help="Checkpoint the VAE, tokenizer, text encoder, UNet and scheduler are loaded from.",
    )
    parser.add_argument(
        "--compbench_category_name",
        type=str,
        default="color",
        choices=["color", "texture", "shape"],
        help="T2I-CompBench category whose prompts are processed.",
    )
    args = parser.parse_args()
    # Valid chunk indices are 0 .. num_chunks - 1 (0 is valid, num_chunks is not).
    if not 0 <= args.chunk_idx < args.num_chunks:
        raise ValueError(
            f"--chunk_idx should be in range [0, --num_chunks) = [0, {args.num_chunks}), got {args.chunk_idx}"
        )
    return args
def get_training_dataloader(
    prompt_directory_path: str,
    args: argparse.Namespace,
    *,
    num_training_samples: int = 30,
    batch_size: int = 8,
) -> DataLoader:
    """Build a shuffled DataLoader over the best-scoring generated images of one prompt.

    For every generator in ``args.generator_directory_names`` the file
    ``<prompt_directory_path>/<generator>/vqa_result.json`` is read. Each key is
    joined onto the generator directory to form an image path (relative to
    ``prompt_directory_path``), and its value is converted to a float score.
    The ``num_training_samples`` highest-scoring images across all generators
    are loaded with ``datasets.load_dataset``.

    Each image is converted to RGB, resized to 512x512, randomly flipped
    horizontally, converted to a tensor and normalized with mean/std 0.5
    (mapping [0, 1] to [-1, 1]).

    Args:
        prompt_directory_path: directory holding one sub-directory per generator.
        args: parsed CLI arguments; only ``generator_directory_names`` is read.
        num_training_samples: how many top-scoring images to keep (default 30).
        batch_size: DataLoader batch size (default 8).

    Returns:
        A DataLoader whose batches are dicts with an ``"images"`` tensor.

    Raises:
        FileNotFoundError: if a generator directory or its ``vqa_result.json`` is missing.
        ValueError: if no image is selected (empty VQA results or
            ``num_training_samples <= 0``).
    """
    missing_generators = sorted(set(args.generator_directory_names) - set(os.listdir(prompt_directory_path)))
    if missing_generators:
        raise FileNotFoundError(f"Some generators are missing in the directory [{', '.join(missing_generators)}]")

    generators_vqa_results = {}
    for generator_directory_name in args.generator_directory_names:
        vqa_result_path = os.path.join(prompt_directory_path, generator_directory_name, 'vqa_result.json')
        if not os.path.isfile(vqa_result_path):
            raise FileNotFoundError(f"VQA Result is missing: \"{vqa_result_path}\"")
        with open(vqa_result_path, encoding='utf-8') as f:
            vqa_scores = json.load(f)
        for image_name, score in vqa_scores.items():
            generators_vqa_results[os.path.join(generator_directory_name, image_name)] = float(score)

    # Highest VQA score first; the sort is stable, so ties keep their reading order.
    ranked_results = sorted(generators_vqa_results.items(), key=lambda item: -item[1])
    top_image_paths = [path for path, _ in ranked_results[:num_training_samples]]
    if not top_image_paths:
        raise ValueError(
            f"No training images selected in \"{prompt_directory_path}\" "
            f"(VQA results are empty or num_training_samples={num_training_samples})"
        )

    dataset = load_dataset(
        prompt_directory_path, data_files={"train": top_image_paths}, split='train'
    )
    preprocess = transforms.Compose(
        [
            transforms.Resize((512, 512)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def transform(examples):
        # Applied lazily by `datasets` on access, so the random flip differs per epoch.
        images = [preprocess(image.convert("RGB")) for image in examples["image"]]
        return {"images": images}

    dataset.set_transform(transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
def get_text_embeddings(prompt: str, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel) -> torch.Tensor:
    """Encode a single prompt with the CLIP text encoder, without gradient tracking.

    The prompt is padded/truncated to ``tokenizer.model_max_length`` tokens and
    fed as a batch of one. The token ids are moved to the encoder's own device,
    so this works wherever the encoder lives (previously hard-coded to CUDA).

    Args:
        prompt: text prompt to encode.
        tokenizer: tokenizer matching ``text_encoder``.
        text_encoder: CLIP text model; only its first output is returned
            (for ``CLIPTextModel`` this is the last hidden state).

    Returns:
        The first element of the encoder output for the one-prompt batch,
        computed under ``torch.no_grad()``.
    """
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_input.input_ids.to(text_encoder.device)
    with torch.no_grad():
        text_embeddings = text_encoder(input_ids)[0]
    return text_embeddings
def train(
    train_dataloader: DataLoader,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    scheduler: PNDMScheduler,
    learnable_text_embedding: torch.nn.Parameter,
    ema_learnable_text_embedding: EMAModel,
    optimizer: torch.optim.Adam,
    opt_scheduler: torch.optim.lr_scheduler.LRScheduler,
    *,
    num_epochs: int = 100,
    max_grad_norm: float = 1.0,
) -> Tuple[torch.nn.Parameter, EMAModel]:
    """Fit ``learnable_text_embedding`` to a set of images with the denoising loss.

    For every batch:
    - the images are encoded to scaled VAE latents;
    - the latents are noised at sampled timesteps;
    - the UNet predicts the noise, conditioned on ``learnable_text_embedding``
      expanded over the batch.

    The MSE between prediction and the sampled noise is back-propagated, the
    embedding's gradient is clipped to ``max_grad_norm``, and ``optimizer`` is
    stepped.

    Args:
        train_dataloader: yields dicts with an ``"images"`` tensor accepted by
            ``vae.encode``.
        vae: VAE used only as a fixed encoder (run under ``torch.no_grad``).
        unet: noise-prediction UNet; gradients flow through it to the embedding.
        scheduler: noise scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        learnable_text_embedding: conditioning tensor being optimized.
        ema_learnable_text_embedding: EMA tracker of the embedding.
        optimizer: optimizer over (presumably only) ``learnable_text_embedding``.
        opt_scheduler: learning-rate scheduler.
        num_epochs: number of passes over ``train_dataloader`` (default 100).
        max_grad_norm: gradient-norm clipping threshold (default 1.0).

    Returns:
        ``learnable_text_embedding`` and ``ema_learnable_text_embedding``, both
        updated in place.

    Raises:
        ValueError: if an epoch yields no samples (the mean loss would be undefined).
    """
    device = vae.device
    num_train_timesteps = scheduler.config.num_train_timesteps
    for epoch in range(num_epochs):
        epoch_loss = 0.
        num_samples = 0
        for batch in train_dataloader:
            # The VAE is frozen and only encodes; no autograd graph is needed for it.
            with torch.no_grad():
                latents = vae.encode(batch["images"].to(device)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            noise = torch.randn_like(latents)
            batch_size = latents.shape[0]
            # Antithetic timestep sampling: draw about half of the timesteps uniformly
            # and mirror them (t -> T - 1 - t), so every batch mixes low and high noise.
            timesteps = torch.randint(0, num_train_timesteps, (batch_size // 2 + 1,), device=latents.device)
            timesteps = torch.cat([timesteps, num_train_timesteps - timesteps - 1], dim=0)[:batch_size]
            timesteps = timesteps.long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
            target = noise  # epsilon prediction: the UNet is trained to recover the added noise
            encoder_hidden_states = learnable_text_embedding.expand(batch_size, -1, -1)
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(learnable_text_embedding, max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item() * batch_size
            num_samples += batch_size
        # NOTE(review): the LR scheduler and the EMA are stepped once per epoch, which
        # makes the caller's MultiStepLR milestones epoch indices; confirm this is the
        # intended schedule. The Parameter itself (not a list of parameters) is passed to
        # the EMA, matching how the caller constructs it; confirm against EMAModel.step.
        opt_scheduler.step()
        ema_learnable_text_embedding.step(learnable_text_embedding)
        if num_samples == 0:
            raise ValueError("train_dataloader yielded no samples")
        # Sample-weighted mean over what was actually seen (robust to drop_last).
        epoch_loss /= num_samples
        print(f"[Epoch {epoch:2d}] Loss: {epoch_loss}", flush=True)
    return learnable_text_embedding, ema_learnable_text_embedding
def load_models(
    args,
    *,
    device: str = 'cuda',
    num_inference_steps: int = 25,
) -> Tuple[AutoencoderKL, CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, PNDMScheduler]:
    """Load the Stable Diffusion components from ``args.stable_diffusion_checkpoint``.

    The VAE, text encoder and UNet are loaded from safetensors weights and moved
    to ``device``. The PNDM scheduler is configured for ``num_inference_steps``
    sampling steps.

    Args:
        args: parsed CLI arguments; only ``stable_diffusion_checkpoint`` is read.
        device: device the three networks are moved to (default ``'cuda'``).
        num_inference_steps: passed to ``scheduler.set_timesteps`` (default 25).

    Returns:
        ``(vae, tokenizer, text_encoder, unet, scheduler)``.
    """
    checkpoint = args.stable_diffusion_checkpoint
    vae = AutoencoderKL.from_pretrained(checkpoint, subfolder="vae", use_safetensors=True)
    tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder="text_encoder", use_safetensors=True)
    unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder="unet", use_safetensors=True)
    scheduler = PNDMScheduler.from_pretrained(checkpoint, subfolder="scheduler")

    vae.to(device)
    text_encoder.to(device)
    unet.to(device)

    scheduler.set_timesteps(num_inference_steps)
    return vae, tokenizer, text_encoder, unet, scheduler
def get_list_chunk(arr: list, num_chunks: int, chunk_idx: int) -> list:
    """Return the ``chunk_idx``-th of ``num_chunks`` contiguous slices of ``arr``.

    Every chunk has ``ceil(len(arr) / num_chunks)`` items, except that the
    trailing chunks may be shorter or empty. The selected range and its
    first/last items are printed for logging.

    Args:
        arr: sequence to split.
        num_chunks: total number of chunks (must be >= 1).
        chunk_idx: zero-based chunk index (must be >= 0); an index past the end
            yields an empty list.

    Returns:
        The slice ``arr[start:end]`` for the requested chunk.

    Raises:
        ValueError: if ``num_chunks < 1`` or ``chunk_idx < 0``.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
    if chunk_idx < 0:
        raise ValueError(f"chunk_idx must be >= 0, got {chunk_idx}")
    arr_len = len(arr)
    chunk_size = (arr_len + num_chunks - 1) // num_chunks  # ceiling division
    # Clamp so that chunks past the end of `arr` are empty instead of out of range.
    start_index = min(chunk_size * chunk_idx, arr_len)
    end_index = min(start_index + chunk_size, arr_len)
    print(f"Choosing chunk ({start_index}:{end_index})")
    if start_index == end_index:
        # Indexing arr[start_index] / arr[end_index-1] here would raise IndexError.
        print("The chunk is empty", flush=True)
    else:
        print(f"First item of the chunk: \"{arr[start_index]}\"")
        print(f"Last item of the chunk: \"{arr[end_index-1]}\"", flush=True)
    return arr[start_index:end_index]
if __name__ == '__main__':
    args = parse_args()
    dataset_base_path = f"./T2I-CompBench-dataset/{args.compbench_category_name}"

    # Load the benchmark prompts. Leading/trailing periods are stripped, so each prompt
    # must match the name of its sample directory under `dataset_base_path`. Duplicates
    # and blank lines are dropped, and the list is sorted so every job sees the same order.
    with open(f'T2I-CompBench-dataset/{args.compbench_category_name}.txt', 'r', encoding='utf-8') as f:
        raw_prompts = f.read().splitlines()
    prompts = sorted({p.strip('.') for p in raw_prompts} - {''})

    # Every prompt needs a directory of previously generated samples. Raise explicitly
    # instead of using `assert`, which `python -O` strips, and name the missing prompts.
    missing_prompts = sorted(set(prompts) - set(os.listdir(dataset_base_path)))
    if missing_prompts:
        raise FileNotFoundError(
            f"{len(missing_prompts)} prompt directories are missing in \"{dataset_base_path}\", "
            f"e.g. \"{missing_prompts[0]}\""
        )

    prompts_chunk = get_list_chunk(prompts, args.num_chunks, args.chunk_idx)

    # Initialization of the models; all are frozen, so only the text embedding is optimized.
    vae, tokenizer, text_encoder, unet, scheduler = load_models(args)
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    for prompt in prompts_chunk:
        print("="*100)
        print(f"[Start of Training] prompt: {prompt}")
        print(f"[Data and Time] {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", flush=True)
        prompt_directory_path = os.path.join(dataset_base_path, prompt)

        # An evaluation result from a previous run means this prompt is already done.
        if os.path.isfile(os.path.join(prompt_directory_path, 'learnable_text_embedding', 'vqa_result.json')):
            print("!"*100)
            print(f"!!! Skipping prompt \"{prompt}\"")
            print("!!! Training and Evaluation has already been done!")
            print("!"*100, flush=True)
            continue

        # Picking some "good" samples for training. A prompt whose data is incomplete is
        # reported and skipped, so one bad prompt does not abort the whole chunk.
        try:
            train_dataloader = get_training_dataloader(prompt_directory_path, args)
        except Exception as error:
            print("!"*100)
            print(f"!!! Failed for prompt \"{prompt}\". Error: {error}")
            print("!"*100, flush=True)
            continue
        print(f"[Training Setup] Training Dataset Size: {len(train_dataloader.dataset)}")
        print(f"[Training Setup] Batch Size: {train_dataloader.batch_size}")

        # Getting the CLIP's text embedding for the prompt
        text_embeddings = get_text_embeddings(prompt, tokenizer, text_encoder)

        # Making the text embedding learnable + creating the EMA Model
        learnable_text_embedding = torch.nn.Parameter(text_embeddings.detach(), requires_grad=True)
        ema_learnable_text_embedding = EMAModel(learnable_text_embedding)
        ema_learnable_text_embedding.to('cuda')

        # Optimizer and Scheduler used for optimizing the learnable text embedding
        optimizer = torch.optim.Adam([learnable_text_embedding], lr=1e-1)
        opt_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 75], gamma=0.1)

        # Training the learnable text embedding
        learnable_text_embedding, ema_learnable_text_embedding = train(
            train_dataloader,
            vae,
            unet,
            scheduler,
            learnable_text_embedding,
            ema_learnable_text_embedding,
            optimizer,
            opt_scheduler
        )

        # Saving the learned text embedding + the EMA version
        torch.save(learnable_text_embedding, os.path.join(prompt_directory_path, 'learnable_text_embedding.pth'))
        torch.save(ema_learnable_text_embedding.state_dict(), os.path.join(prompt_directory_path, 'ema_learnable_text_embedding.pth'))

        print(f"[Start of Evaluation] prompt: {prompt}", flush=True)
        # Evaluation of the learned (non-EMA) text embedding by generating 100 samples
        # and calculating the BLIP VQA score
        _, average_score = generate_samples_and_evaluate_blip_vqa(
            vae,
            unet,
            scheduler,
            tokenizer,
            text_encoder,
            prompt=[prompt],
            fixed_text_embeddings=learnable_text_embedding.data.detach(),
            evaluation_path=os.path.join(prompt_directory_path, 'learnable_text_embedding'),
            batch_size=10,
            num_evaluation_images=100,
        )
        print(f"[Finished Evaluation] Prompt: {prompt}")
        print(f"[Finished Evaluation] Average Score: {average_score}")
        print(f"[Data and Time] {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"[Finish] prompt: {prompt}")
        print("="*100, flush=True)