Use dynamic FLUX T5 seq len (small speed / mem improvement) #7400

Draft · wants to merge 2 commits into main
invokeai/app/invocations/flux_text_encoder.py (16 additions, 5 deletions)
```diff
@@ -29,7 +29,7 @@
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -48,6 +48,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
     t5_max_seq_len: Literal[256, 512] = InputField(
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
+    use_short_t5_seq_len: bool = InputField(
+        description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
+        + "performance and reduce peak memory, but may result in slightly different image outputs.",
+        default=True,
+    )
     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     mask: Optional[TensorField] = InputField(
         default=None, description="A mask defining the region that this conditioning prompt applies to."
@@ -74,17 +79,23 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
 
         prompt = [self.prompt]
 
+        valid_seq_lens = [self.t5_max_seq_len]
+        if self.use_short_t5_seq_len:
+            # We allow a minimum sequence length of 128. Going too short results in more significant image changes.
+            valid_seq_lens = list(range(128, self.t5_max_seq_len, 128))
+            valid_seq_lens.append(self.t5_max_seq_len)
+
         with (
             t5_text_encoder_info as t5_text_encoder,
             t5_tokenizer_info as t5_tokenizer,
         ):
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, T5Tokenizer)
 
-            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
+            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)
 
             context.util.signal_progress("Running T5 encoder")
-            prompt_embeds = t5_encoder(prompt)
+            prompt_embeds = t5_encoder(prompt, valid_seq_lens)
 
             assert isinstance(prompt_embeds, torch.Tensor)
             return prompt_embeds
@@ -122,10 +133,10 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
                 # There are currently no supported CLIP quantized models. Add support here if needed.
                 raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
 
-            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)
 
             context.util.signal_progress("Running CLIP encoder")
-            pooled_prompt_embeds = clip_encoder(prompt)
+            pooled_prompt_embeds = clip_encoder(prompt, [77])
 
             assert isinstance(pooled_prompt_embeds, torch.Tensor)
             return pooled_prompt_embeds
```
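For reference, the sequence-length selection above boils down to a little arithmetic: build the set of allowed lengths (multiples of 128 up to the model's maximum) and pick the shortest one that can hold the tokenized prompt. The sketch below is a standalone illustration with assumed helper names (`build_valid_seq_lens`, `select_seq_len`) and made-up token counts; it is not code from this PR.

```python
# Illustrative sketch of the dynamic T5 sequence-length selection, mirroring the
# logic added in FluxTextEncoderInvocation._t5_encode and HFEncoder.forward.
# Helper names and the example token counts are assumptions for demonstration.

def build_valid_seq_lens(t5_max_seq_len: int, use_short_t5_seq_len: bool = True) -> list[int]:
    """Multiples of 128 up to and including the model's max sequence length."""
    valid_seq_lens = [t5_max_seq_len]
    if use_short_t5_seq_len:
        valid_seq_lens = list(range(128, t5_max_seq_len, 128))
        valid_seq_lens.append(t5_max_seq_len)
    return valid_seq_lens


def select_seq_len(prompt_token_count: int, valid_seq_lens: list[int]) -> int:
    """Shortest valid length that fits the prompt; fall back to the largest."""
    for candidate in sorted(valid_seq_lens):
        if candidate >= prompt_token_count:
            return candidate
    return max(valid_seq_lens)


print(build_valid_seq_lens(512))                  # [128, 256, 384, 512] (FLUX dev)
print(build_valid_seq_lens(256))                  # [128, 256]           (FLUX schnell)
print(select_seq_len(40, [128, 256, 384, 512]))   # 128: short prompt pads to 128 tokens instead of 512
print(select_seq_len(300, [128, 256, 384, 512]))  # 384
```

Since the T5 output is consumed without an attention mask (`attention_mask=None` in `HFEncoder`), every padding token is carried through the downstream attention, so encoding a short prompt at 128 tokens rather than 512 is presumably where the small speed and peak-memory win in the PR title comes from.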
invokeai/backend/flux/modules/conditioner.py (27 additions, 6 deletions)
```diff
@@ -1,32 +1,53 @@
 # Initially pulled from https://github.com/black-forest-labs/flux
 
 
 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer
 
 
 class HFEncoder(nn.Module):
-    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
+    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
         super().__init__()
-        self.max_length = max_length
         self.is_clip = is_clip
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
         self.tokenizer = tokenizer
         self.hf_module = encoder
         self.hf_module = self.hf_module.eval().requires_grad_(False)
 
-    def forward(self, text: list[str]) -> Tensor:
+    def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
+        """Encode text into a tensor.
+
+        Args:
+            text: A list of text prompts to encode.
+            valid_seq_lens: A list of valid sequence lengths. The shortest valid sequence length that can contain the
+                text will be used. If the largest valid sequence length cannot contain the text, the encoding will be
+                truncated.
+        """
+        valid_seq_lens = sorted(valid_seq_lens)
+
+        # Perform initial encoding with the maximum valid sequence length.
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
-            max_length=self.max_length,
-            return_length=False,
+            max_length=max(valid_seq_lens),
+            return_length=True,
             return_overflowing_tokens=False,
             padding="max_length",
             return_tensors="pt",
         )
 
+        # Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
+        seq_len: int = batch_encoding["length"][0].item()
+        selected_seq_len = valid_seq_lens[-1]
+        for candidate_len in valid_seq_lens:
+            if candidate_len >= seq_len:
+                selected_seq_len = candidate_len
+                break
+
+        input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
+
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=input_ids.to(self.hf_module.device),
             attention_mask=None,
             output_hidden_states=False,
         )
```
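With `max_length` removed from the constructor, each call now supplies its own set of acceptable lengths. A minimal usage sketch, assuming the encoder and tokenizer objects are already loaded from the invocation context as in `flux_text_encoder.py` above:

```python
# Minimal usage sketch of the new HFEncoder call signature. The model and
# tokenizer objects (t5_text_encoder, t5_tokenizer, clip_text_encoder,
# clip_tokenizer) are assumed to be loaded elsewhere, as in flux_text_encoder.py.
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, is_clip=False)
prompt_embeds = t5_encoder(["a photo of a cat"], valid_seq_lens=[128, 256, 384, 512])
# The short prompt fits in 128 tokens, so T5 runs with a padded length of 128.

clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, is_clip=True)
pooled_prompt_embeds = clip_encoder(["a photo of a cat"], valid_seq_lens=[77])
# CLIP keeps its fixed 77-token length; only the T5 path is dynamic.
```

Because `t5_max_seq_len` is always kept in `valid_seq_lens`, the worst case matches the old fixed-length behaviour, and prompts longer than the largest valid length are still truncated exactly as before.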