diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index 1eb0fea62e6..ea5a9a22bf7 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -29,7 +29,7 @@
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -48,6 +48,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
     t5_max_seq_len: Literal[256, 512] = InputField(
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
+    use_short_t5_seq_len: bool = InputField(
+        description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
+        + "performance and reduce peak memory, but may result in slightly different image outputs.",
+        default=True,
+    )
     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     mask: Optional[TensorField] = InputField(
         default=None, description="A mask defining the region that this conditioning prompt applies to."
@@ -74,6 +79,10 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
 
         prompt = [self.prompt]
 
+        valid_seq_lens = [self.t5_max_seq_len]
+        if self.use_short_t5_seq_len:
+            valid_seq_lens = [128, 256, 512]
+
         with (
             t5_text_encoder_info as t5_text_encoder,
             t5_tokenizer_info as t5_tokenizer,
@@ -81,10 +90,10 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, T5Tokenizer)
 
-            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
+            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)
 
             context.util.signal_progress("Running T5 encoder")
-            prompt_embeds = t5_encoder(prompt)
+            prompt_embeds = t5_encoder(prompt, valid_seq_lens)
 
         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -122,10 +131,10 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
                 # There are currently no supported CLIP quantized models. Add support here if needed.
                 raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
 
-            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)
 
             context.util.signal_progress("Running CLIP encoder")
-            pooled_prompt_embeds = clip_encoder(prompt)
+            pooled_prompt_embeds = clip_encoder(prompt, [77])
 
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return pooled_prompt_embeds
diff --git a/invokeai/backend/flux/modules/conditioner.py b/invokeai/backend/flux/modules/conditioner.py
index de6d8256c4f..7207eea45fa 100644
--- a/invokeai/backend/flux/modules/conditioner.py
+++ b/invokeai/backend/flux/modules/conditioner.py
@@ -1,32 +1,43 @@
 # Initially pulled from https://github.com/black-forest-labs/flux
+
 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer
 
 
 class HFEncoder(nn.Module):
-    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
+    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
         super().__init__()
-        self.max_length = max_length
         self.is_clip = is_clip
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
         self.tokenizer = tokenizer
         self.hf_module = encoder
         self.hf_module = self.hf_module.eval().requires_grad_(False)
 
-    def forward(self, text: list[str]) -> Tensor:
+    def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
+        valid_seq_lens = sorted(valid_seq_lens)
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
-            max_length=self.max_length,
-            return_length=False,
+            max_length=max(valid_seq_lens),
+            return_length=True,
             return_overflowing_tokens=False,
             padding="max_length",
             return_tensors="pt",
         )
 
+        seq_len: int = batch_encoding["length"][0].item()
+        # Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
+        selected_seq_len = valid_seq_lens[-1]
+        for len in valid_seq_lens:
+            if len >= seq_len:
+                selected_seq_len = len
+                break
+
+        input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
+
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=input_ids.to(self.hf_module.device),
             attention_mask=None,
             output_hidden_states=False,
         )
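
For reference, the new logic in HFEncoder.forward amounts to tokenizing against the largest allowed length, reading back the real token count, and picking the smallest entry in valid_seq_lens that can hold it. Below is a minimal standalone sketch of that selection, assuming a Hugging Face (slow) T5 tokenizer as used in the patch; the pick_seq_len helper and the model ID in the trailing comment are illustrative and not part of the patch.

from transformers import T5Tokenizer


def pick_seq_len(prompt: str, tokenizer: T5Tokenizer, valid_seq_lens: list[int]) -> int:
    """Return the smallest valid sequence length that can contain all of the prompt's tokens."""
    valid_seq_lens = sorted(valid_seq_lens)
    encoding = tokenizer(
        [prompt],  # batch of one, mirroring how HFEncoder.forward calls the tokenizer
        truncation=True,
        max_length=max(valid_seq_lens),  # tokenize against the largest allowed length
        return_length=True,  # ask the tokenizer to report the token count
        padding="max_length",
        return_tensors="pt",
    )
    seq_len = int(encoding["length"][0].item())
    # Smallest valid length that fits; fall back to the largest if the prompt fills it completely.
    return next((n for n in valid_seq_lens if n >= seq_len), valid_seq_lens[-1])


# Hypothetical usage: a short prompt selects 128 tokens rather than padding out to 512.
# tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
# pick_seq_len("a photo of a cat", tokenizer, [128, 256, 512])  # -> 128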