From 8d04ec3f95f3fac50c3dbca2006138f56159a228 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 29 Nov 2024 16:11:51 +0000 Subject: [PATCH] Improve docs related to dynamic T5 sequence length selection. --- invokeai/app/invocations/flux_text_encoder.py | 4 +++- invokeai/backend/flux/modules/conditioner.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index ea5a9a22bf7..dd6ed6b158b 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -81,7 +81,9 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor: valid_seq_lens = [self.t5_max_seq_len] if self.use_short_t5_seq_len: - valid_seq_lens = [128, 256, 512] + # We allow a minimum sequence length of 128. Going too short results in more significant image changes. + valid_seq_lens = list(range(128, self.t5_max_seq_len, 128)) + valid_seq_lens.append(self.t5_max_seq_len) with ( t5_text_encoder_info as t5_text_encoder, diff --git a/invokeai/backend/flux/modules/conditioner.py b/invokeai/backend/flux/modules/conditioner.py index 7207eea45fa..6fb17a22975 100644 --- a/invokeai/backend/flux/modules/conditioner.py +++ b/invokeai/backend/flux/modules/conditioner.py @@ -15,7 +15,17 @@ def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_ self.hf_module = self.hf_module.eval().requires_grad_(False) def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor: + """Encode text into a tensor. + + Args: + text: A list of text prompts to encode. + valid_seq_lens: A list of valid sequence lengths. The shortest valid sequence length that can contain the + text will be used. If the largest valid sequence length cannot contain the text, the encoding will be + truncated. + """ valid_seq_lens = sorted(valid_seq_lens) + + # Perform initial encoding with the maximum valid sequence length. 
batch_encoding = self.tokenizer( text, truncation=True, @@ -26,8 +36,8 @@ def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor: return_tensors="pt", ) - seq_len: int = batch_encoding["length"][0].item() # Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens. + seq_len: int = batch_encoding["length"][0].item() selected_seq_len = valid_seq_lens[-1] for len in valid_seq_lens: if len >= seq_len: