-
Notifications
You must be signed in to change notification settings - Fork 22
/
util.py
333 lines (289 loc) · 10.4 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import json
from pathlib import Path
from typing import Literal, Optional
import torch
from modules.autoencoder import AutoEncoder, AutoEncoderParams
from modules.conditioner import HFEmbedder
from modules.flux_model import Flux, FluxParams
from safetensors.torch import load_file as load_sft
try:
    from enum import StrEnum
except ImportError:
    # Python < 3.11 has no enum.StrEnum. Approximate it with a str-mixin Enum
    # so members compare equal to (and behave as) their string values.
    from enum import Enum

    class StrEnum(str, Enum):
        """Fallback for ``enum.StrEnum`` on Python < 3.11."""

        # Match the stdlib StrEnum: str(member) returns the value, not
        # "ClassName.member".
        __str__ = str.__str__
from pydantic import BaseModel, ConfigDict
from loguru import logger
class ModelVersion(StrEnum):
flux_dev = "flux-dev"
flux_schnell = "flux-schnell"
class QuantizationDtype(StrEnum):
qfloat8 = "qfloat8"
qint2 = "qint2"
qint4 = "qint4"
qint8 = "qint8"
bfloat16 = "bfloat16"
float16 = "float16"
class ModelSpec(BaseModel):
    """Everything needed to build and place the FLUX pipeline's models.

    Holds the model variant, architecture hyper-parameters for the flow model
    and autoencoder, weight sources (local paths and Hugging Face repo/file
    names), per-component device and dtype names, and the quantization,
    offload and compile switches.
    """

    version: ModelVersion
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    # Lets callers supply a custom CLIP model; the default is a Hugging Face hub id.
    clip_path: str | None = "openai/clip-vit-large-patch14"
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
    text_enc_max_length: int = 512
    text_enc_path: str | None

    # Placement of each component.
    text_enc_device: str | torch.device | None = "cuda:0"
    ae_device: str | torch.device | None = "cuda:0"
    flux_device: str | torch.device | None = "cuda:0"

    # dtype names, resolved elsewhere in this module via into_dtype().
    flow_dtype: str = "float16"
    ae_dtype: str = "bfloat16"
    text_enc_dtype: str = "bfloat16"

    # Marked "unused / deprecated" in the original source (at least num_to_quant;
    # NOTE(review): confirm which of the following fields that note covers).
    num_to_quant: int | None = 20
    quantize_extras: bool = False
    compile_extras: bool = False
    compile_blocks: bool = False

    # Per-component quantization; None leaves that component unquantized.
    flow_quantization_dtype: QuantizationDtype | None = QuantizationDtype.qfloat8
    text_enc_quantization_dtype: QuantizationDtype | None = QuantizationDtype.qfloat8
    ae_quantization_dtype: QuantizationDtype | None = None
    clip_quantization_dtype: QuantizationDtype | None = None

    offload_text_encoder: bool = False
    offload_vae: bool = False
    offload_flow: bool = False
    prequantized_flow: bool = False

    # Set to False to keep the modulation linear layers unquantized
    # (better precision).
    quantize_modulation: bool = True
    # Leave False to keep the flow embedder layers unquantized
    # (better precision).
    quantize_flow_embedder_layers: bool = False

    # arbitrary_types_allowed: needed for non-pydantic field types such as
    # torch.device and the params classes.
    # use_enum_values: store enum fields as their plain string values.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )
def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]:
    """Build all four models described by ``config``.

    Components are loaded in this order: flow model, autoencoder, then the
    CLIP and T5 text encoders.

    Returns:
        ``(flow, ae, clip, t5)``.
    """
    flow_model = load_flow_model(config)
    autoencoder = load_autoencoder(config)
    clip_encoder, t5_encoder = load_text_encoders(config)
    return (flow_model, autoencoder, clip_encoder, t5_encoder)
def parse_device(device: str | torch.device | None) -> torch.device:
if isinstance(device, str):
return torch.device(device)
elif isinstance(device, torch.device):
return device
else:
return torch.device("cuda:0")
def into_dtype(dtype: str) -> torch.dtype:
if isinstance(dtype, torch.dtype):
return dtype
if dtype == "float16":
return torch.float16
elif dtype == "bfloat16":
return torch.bfloat16
elif dtype == "float32":
return torch.float32
else:
raise ValueError(f"Invalid dtype: {dtype}")
def into_device(device: str | torch.device | None) -> torch.device:
if isinstance(device, str):
return torch.device(device)
elif isinstance(device, torch.device):
return device
elif isinstance(device, int):
return torch.device(f"cuda:{device}")
else:
return torch.device("cuda:0")
def load_config(
    name: ModelVersion = ModelVersion.flux_dev,
    flux_path: str | None = None,
    ae_path: str | None = None,
    text_enc_path: str | None = None,
    text_enc_device: str | torch.device | None = None,
    ae_device: str | torch.device | None = None,
    flux_device: str | torch.device | None = None,
    flow_dtype: str = "float16",
    ae_dtype: str = "bfloat16",
    text_enc_dtype: str = "bfloat16",
    num_to_quant: Optional[int] = 20,
    compile_extras: bool = False,
    compile_blocks: bool = False,
    offload_text_enc: bool = False,
    offload_ae: bool = False,
    offload_flow: bool = False,
    quant_text_enc: Optional[
        Literal["float8", "qfloat8", "qint2", "qint4", "qint8"]
    ] = None,
    quant_ae: bool = False,
    prequantized_flow: bool = False,
    quantize_modulation: bool = True,
    quantize_flow_embedder_layers: bool = False,
) -> ModelSpec:
    """
    Load a model configuration using the passed arguments.

    The architecture hyper-parameters (FluxParams / AutoEncoderParams) are
    fixed here; everything else comes from the arguments.

    Args:
        name: Model variant. ``flux_dev`` selects the FLUX.1-dev repo/file,
            enables the guidance embedding and uses a 512-token text-encoder
            length; any other value selects FLUX.1-schnell with 256 tokens.
        flux_path: Local flow-model checkpoint (stored as ``ckpt_path``).
        ae_path: Local autoencoder checkpoint.
        text_enc_path: T5 text-encoder path or hub id.
        text_enc_device, ae_device, flux_device: Normalized with
            ``parse_device`` and stored as strings; ``None`` -> ``"cuda:0"``.
        flow_dtype, ae_dtype, text_enc_dtype: dtype names, stored as given.
        quant_text_enc: Text-encoder quantization. ``"float8"`` and
            ``"qfloat8"`` map to qfloat8; ``"qint2"``/``"qint4"``/``"qint8"``
            map to the matching QuantizationDtype. ``None`` disables it. Any
            other value also disables it, and a warning is logged.
        quant_ae: If true, the autoencoder quantization dtype is qfloat8;
            otherwise None.
        num_to_quant, compile_extras, compile_blocks, offload_*,
        prequantized_flow, quantize_modulation, quantize_flow_embedder_layers:
            Copied into the corresponding ModelSpec fields.

    Returns:
        The populated ModelSpec.
    """
    text_enc_device = str(parse_device(text_enc_device))
    ae_device = str(parse_device(ae_device))
    flux_device = str(parse_device(flux_device))

    # Accept both the short spelling ("float8") and the QuantizationDtype
    # values themselves ("qfloat8", ...). Previously "qfloat8" silently
    # resolved to "no quantization".
    text_enc_quant_by_name = {
        "float8": QuantizationDtype.qfloat8,
        "qfloat8": QuantizationDtype.qfloat8,
        "qint2": QuantizationDtype.qint2,
        "qint4": QuantizationDtype.qint4,
        "qint8": QuantizationDtype.qint8,
    }
    text_enc_quantization_dtype = text_enc_quant_by_name.get(quant_text_enc, None)
    if quant_text_enc is not None and text_enc_quantization_dtype is None:
        logger.warning(
            f"Unrecognized quant_text_enc={quant_text_enc!r}; "
            "the text encoder will not be quantized"
        )

    is_dev = name == ModelVersion.flux_dev
    return ModelSpec(
        version=name,
        repo_id=(
            "black-forest-labs/FLUX.1-dev"
            if is_dev
            else "black-forest-labs/FLUX.1-schnell"
        ),
        repo_flow="flux1-dev.sft" if is_dev else "flux1-schnell.sft",
        repo_ae="ae.sft",
        ckpt_path=flux_path,
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=is_dev,
        ),
        ae_path=ae_path,
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
        text_enc_path=text_enc_path,
        text_enc_device=text_enc_device,
        ae_device=ae_device,
        flux_device=flux_device,
        flow_dtype=flow_dtype,
        ae_dtype=ae_dtype,
        text_enc_dtype=text_enc_dtype,
        text_enc_max_length=512 if is_dev else 256,
        num_to_quant=num_to_quant,
        compile_extras=compile_extras,
        compile_blocks=compile_blocks,
        offload_flow=offload_flow,
        offload_text_encoder=offload_text_enc,
        offload_vae=offload_ae,
        text_enc_quantization_dtype=text_enc_quantization_dtype,
        ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
        prequantized_flow=prequantized_flow,
        quantize_modulation=quantize_modulation,
        quantize_flow_embedder_layers=quantize_flow_embedder_layers,
    )
def load_config_from_path(path: str) -> ModelSpec:
path_path = Path(path)
if not path_path.exists():
raise ValueError(f"Path {path} does not exist")
if not path_path.is_file():
raise ValueError(f"Path {path} is not a file")
return ModelSpec(**json.loads(path_path.read_text()))
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Report state-dict key mismatches through the module logger.

    Logs one warning listing the missing keys and one listing the unexpected
    keys, each only if that list is non-empty. When both are non-empty, a
    dashed separator warning is logged between them. Nothing is logged when
    both lists are empty.
    """
    n_missing = len(missing)
    n_unexpected = len(unexpected)
    if n_missing:
        logger.warning(f"Got {n_missing} missing keys:\n\t" + "\n\t".join(missing))
    if n_missing and n_unexpected:
        logger.warning("\n" + "-" * 79 + "\n")
    if n_unexpected:
        logger.warning(
            f"Got {n_unexpected} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
def load_flow_model(config: ModelSpec) -> Flux:
    """Construct the Flux flow model and, if configured, load its weights.

    The model is built on the ``meta`` device. Weights from
    ``config.ckpt_path`` are read onto the CPU and attached with
    ``assign=True``. Unless ``config.prequantized_flow`` is set, the model is
    cast to ``config.flow_dtype`` both before and after loading.

    NOTE: with no ``ckpt_path`` the returned model still has meta-device
    (storage-less) parameters.
    """
    flow_dtype = into_dtype(config.flow_dtype)
    needs_cast = not config.prequantized_flow

    with torch.device("meta"):
        model = Flux(config, dtype=flow_dtype)
        if needs_cast:
            model.type(flow_dtype)

    checkpoint = config.ckpt_path
    if checkpoint is not None:
        # load_sft doesn't support torch.device
        state_dict = load_sft(checkpoint, device="cpu")
        missing, unexpected = model.load_state_dict(
            state_dict, strict=False, assign=True
        )
        print_load_warning(missing, unexpected)

    # Re-cast so loaded tensors (assigned as-is from the file) match flow_dtype.
    if needs_cast:
        model.type(flow_dtype)
    return model
def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
    """Build the CLIP and T5 text embedders described by ``config``.

    Both share ``config.text_enc_dtype`` and the device index derived from
    ``config.text_enc_device``. CLIP uses ``config.clip_path`` with a
    77-token limit. T5 uses ``config.text_enc_path`` with
    ``config.text_enc_max_length``.

    Returns:
        ``(clip, t5)``.
    """
    encoder_dtype = into_dtype(config.text_enc_dtype)
    # NOTE(review): a device without an explicit index (e.g. "cpu" or "cuda")
    # has .index None and collapses to 0 here; confirm HFEmbedder treats an
    # int 0 as intended in that case.
    encoder_device = into_device(config.text_enc_device).index or 0

    clip = HFEmbedder(
        config.clip_path,
        max_length=77,
        torch_dtype=encoder_dtype,
        device=encoder_device,
        is_clip=True,
        quantization_dtype=config.clip_quantization_dtype,
    )
    t5 = HFEmbedder(
        config.text_enc_path,
        max_length=config.text_enc_max_length,
        torch_dtype=encoder_dtype,
        device=encoder_device,
        quantization_dtype=config.text_enc_quantization_dtype,
    )
    return clip, t5
def load_autoencoder(config: ModelSpec) -> AutoEncoder:
    """Construct the autoencoder, load its weights, and place it.

    Steps:
      1. Build ``AutoEncoder(config.ae_params)`` in ``config.ae_dtype``. It is
         built on ``meta`` when a checkpoint will be loaded, otherwise directly
         on the target device.
      2. If ``config.ae_path`` is set, load it onto the target device with
         ``assign=True`` and report key mismatches.
      3. Move/cast to the target device and dtype.
      4. If ``config.ae_quantization_dtype`` is set, swap linear layers via
         ``float8_quantize.recursive_swap_linears``.
      5. If ``config.offload_vae`` is set, move to CPU and empty the CUDA cache.
    """
    ckpt_path = config.ae_path
    # Resolve the device once so None / int / str / torch.device specs are all
    # handled the same way. Previously the raw field was passed to
    # torch.device(...) and str(...), which broke for None.
    device = into_device(config.ae_device)
    dtype = into_dtype(config.ae_dtype)

    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(config.ae_params).to(dtype)

    if ckpt_path is not None:
        # load_sft expects a device string rather than a torch.device.
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)

    ae.to(device=device, dtype=dtype)

    if config.ae_quantization_dtype is not None:
        from float8_quantize import recursive_swap_linears

        recursive_swap_linears(ae)

    if config.offload_vae:
        ae.to("cpu")
        torch.cuda.empty_cache()
    return ae
class LoadedModels(BaseModel):
    """The four loaded models together with the ModelSpec that produced them."""

    flow: Flux
    ae: AutoEncoder
    clip: HFEmbedder
    t5: HFEmbedder
    config: ModelSpec

    # arbitrary_types_allowed lets the non-pydantic model classes be fields;
    # use_enum_values stores enum values rather than members.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )
def load_models_from_config_path(
    path: str,
) -> LoadedModels:
    """Read a JSON ModelSpec from ``path`` and load every model it describes.

    Equivalent to ``load_models_from_config(load_config_from_path(path))``:
    text encoders first, then the flow model, then the autoencoder.
    """
    return load_models_from_config(load_config_from_path(path))
def load_models_from_config(config: ModelSpec) -> LoadedModels:
    """Load every model described by ``config`` and bundle them.

    Loading order: CLIP/T5 text encoders, flow model, autoencoder.
    """
    clip_encoder, t5_encoder = load_text_encoders(config)
    flow_model = load_flow_model(config)
    autoencoder = load_autoencoder(config)
    return LoadedModels(
        flow=flow_model,
        ae=autoencoder,
        clip=clip_encoder,
        t5=t5_encoder,
        config=config,
    )