From 64364e791184447fa695831e49392cffd0b5ab35 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 27 Nov 2024 22:40:10 +0000 Subject: [PATCH] Short-circuit if there are no region masks in FLUX and don't apply attention masking. --- .../regional_prompting_extension.py | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 834463137a3..f1086c32866 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -110,12 +110,25 @@ def _prepare_restricted_attn_mask( cls, regional_text_conditioning: FluxRegionalTextConditioning, img_seq_len: int, - ) -> torch.Tensor: + ) -> torch.Tensor | None: """Prepare a 'restricted' attention mask. In this context, 'restricted' means that: - img self-attention is only allowed within regions. - img regions only attend to txt within their own region, not to global prompts. - """ + # Identify background region. I.e. the region that is not covered by any region masks. + background_region_mask: None | torch.Tensor = None + for image_mask in regional_text_conditioning.image_masks: + if image_mask is not None: + if background_region_mask is None: + background_region_mask = torch.ones_like(image_mask) + background_region_mask *= 1 - image_mask + + if background_region_mask is None: + # There are no region masks, short-circuit and return None. + # TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this would + # is a rare use case and would make the logic here significantly more complicated. + return None + device = TorchDevice.choose_torch_device() # Infer txt_seq_len from the t5_embeddings tensor. @@ -134,14 +147,6 @@ def _prepare_restricted_attn_mask( (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16 ) - # Identify background region. I.e. the region that is not covered by any region masks. - background_region_mask: None | torch.Tensor = None - for image_mask in regional_text_conditioning.image_masks: - if image_mask is not None: - if background_region_mask is None: - background_region_mask = torch.ones_like(image_mask) - background_region_mask *= 1 - image_mask - for image_mask, t5_embedding_range in zip( regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True ): @@ -167,10 +172,6 @@ def _prepare_restricted_attn_mask( image_mask = image_mask.view(img_seq_len, 1) regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T else: - if background_region_mask is None: - # There are no region masks, so we don't need to do anything here - this case is handled below. - continue - # We don't allow attention between non-background image regions and global prompts. This helps to ensure # that regions focus on their local prompts. We do, however, allow attention between background regions # and global prompts. If we didn't do this, then the background regions would not attend to any txt @@ -188,15 +189,9 @@ def _prepare_restricted_attn_mask( background_region_mask.view(img_seq_len, 1) ) - # Handle image background regions. - if background_region_mask is None: - # There are no region masks, so allow unrestricted img-img attention, and unrestricted img-txt attention. - regional_attention_mask[txt_seq_len:, :] = 1.0 - regional_attention_mask[:, txt_seq_len:] = 1.0 - else: - # Allow background regions to attend to themselves. - regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) - regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) + # Allow background regions to attend to themselves. + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) # Convert attention mask to boolean. regional_attention_mask = regional_attention_mask > 0.5