
Commit 4581a37
Dynamically select smaller t5 seq len to save inference time.
RyanJDick committed Nov 29, 2024
1 parent 54b7f9a · commit 4581a37
Showing 2 changed files with 31 additions and 11 deletions.
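
The core idea of the change: instead of always padding the prompt to the model's maximum T5 sequence length, the encoder now picks the smallest candidate length that can hold all of the prompt's tokens. A minimal standalone sketch of that selection rule (the helper name select_seq_len is illustrative, not part of the commit):

def select_seq_len(num_tokens: int, valid_seq_lens: list[int]) -> int:
    # Pick the smallest candidate sequence length that fits the prompt's tokens;
    # fall back to the largest candidate when the prompt overflows every bucket
    # (the tokenizer truncates to that maximum anyway).
    for candidate in sorted(valid_seq_lens):
        if candidate >= num_tokens:
            return candidate
    return max(valid_seq_lens)

assert select_seq_len(40, [128, 256, 512]) == 128   # short prompt -> small bucket
assert select_seq_len(200, [128, 256, 512]) == 256
assert select_seq_len(600, [128, 256, 512]) == 512  # overflow -> largest bucket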
19 changes: 14 additions & 5 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -29,7 +29,7 @@
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -48,6 +48,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
     t5_max_seq_len: Literal[256, 512] = InputField(
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
+    use_short_t5_seq_len: bool = InputField(
+        description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
+        + "performance and reduce peak memory, but may result in slightly different image outputs.",
+        default=True,
+    )
     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     mask: Optional[TensorField] = InputField(
         default=None, description="A mask defining the region that this conditioning prompt applies to."
@@ -74,17 +79,21 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:

         prompt = [self.prompt]

+        valid_seq_lens = [self.t5_max_seq_len]
+        if self.use_short_t5_seq_len:
+            valid_seq_lens = [128, 256, 512]
+
         with (
             t5_text_encoder_info as t5_text_encoder,
             t5_tokenizer_info as t5_tokenizer,
         ):
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, T5Tokenizer)

-            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
+            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)

             context.util.signal_progress("Running T5 encoder")
-            prompt_embeds = t5_encoder(prompt)
+            prompt_embeds = t5_encoder(prompt, valid_seq_lens)

         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -122,10 +131,10 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
                 # There are currently no supported CLIP quantized models. Add support here if needed.
                 raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")

-            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)

             context.util.signal_progress("Running CLIP encoder")
-            pooled_prompt_embeds = clip_encoder(prompt)
+            pooled_prompt_embeds = clip_encoder(prompt, [77])

         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return pooled_prompt_embeds
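In the node above, the new flag translates into a list of candidate lengths that is passed to the encoder per call. A sketch of that decision as a standalone function (hypothetical helper; the commit inlines this logic in _t5_encode):

def candidate_seq_lens(use_short_t5_seq_len: bool, t5_max_seq_len: int) -> list[int]:
    # With the flag on, shorter buckets are allowed. Note that the candidate list
    # is fixed at [128, 256, 512] regardless of t5_max_seq_len, mirroring the diff;
    # with the flag off, the encoder always pads to the configured maximum.
    return [128, 256, 512] if use_short_t5_seq_len else [t5_max_seq_len]

assert candidate_seq_lens(True, 512) == [128, 256, 512]
assert candidate_seq_lens(False, 256) == [256]

The CLIP path is unchanged in effect: it always runs at the fixed length 77, now passed per call as [77] rather than fixed at construction time.
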
23 changes: 17 additions & 6 deletions invokeai/backend/flux/modules/conditioner.py
@@ -1,32 +1,43 @@
 # Initially pulled from https://github.com/black-forest-labs/flux


 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer


 class HFEncoder(nn.Module):
-    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
+    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
         super().__init__()
-        self.max_length = max_length
         self.is_clip = is_clip
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
         self.tokenizer = tokenizer
         self.hf_module = encoder
         self.hf_module = self.hf_module.eval().requires_grad_(False)

-    def forward(self, text: list[str]) -> Tensor:
+    def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
+        valid_seq_lens = sorted(valid_seq_lens)
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
-            max_length=self.max_length,
-            return_length=False,
+            max_length=max(valid_seq_lens),
+            return_length=True,
             return_overflowing_tokens=False,
             padding="max_length",
             return_tensors="pt",
         )

+        seq_len: int = batch_encoding["length"][0].item()
+        # Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
+        selected_seq_len = valid_seq_lens[-1]
+        for len in valid_seq_lens:
+            if len >= seq_len:
+                selected_seq_len = len
+                break
+
+        input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
+
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=input_ids.to(self.hf_module.device),
             attention_mask=None,
             output_hidden_states=False,
         )
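Because forward calls the encoder with attention_mask=None, pad tokens are attended to, so the amount of padding can shift the embeddings slightly; that is what the new field's warning about "slightly different image outputs" refers to. The pad-then-slice step itself is straightforward. A toy reproduction in plain torch (no model or tokenizer required; all values are illustrative):

import torch

# Tokenizer output padded to the largest bucket (max_length = 512 here):
# real token ids followed by pad ids.
max_len, selected_seq_len, num_real_tokens = 512, 128, 40
pad_id = 0
input_ids = torch.full((1, max_len), pad_id)
input_ids[0, :num_real_tokens] = torch.arange(1, num_real_tokens + 1)

# Slicing to the selected bucket keeps every real token and only as much
# padding as the smaller bucket requires.
trimmed = input_ids[..., :selected_seq_len]
assert trimmed.shape == (1, 128)
assert torch.equal(trimmed[0, :num_real_tokens], input_ids[0, :num_real_tokens])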
