From 6565cea039be1b808c7055622584a2361eff6861 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 27 Nov 2024 22:16:44 +0000
Subject: [PATCH] Comment unused _prepare_unrestricted_attn_mask(...) for future reference.

---
 .../regional_prompting_extension.py | 118 +++++++++---------
 1 file changed, 57 insertions(+), 61 deletions(-)

diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py
index f5f203af695..834463137a3 100644
--- a/invokeai/backend/flux/extensions/regional_prompting_extension.py
+++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py
@@ -12,18 +12,16 @@ class RegionalPromptingExtension:
     """A class for managing regional prompting with FLUX.

-    Implementation inspired by: https://arxiv.org/pdf/2411.02395
+    This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
     """

     def __init__(
         self,
         regional_text_conditioning: FluxRegionalTextConditioning,
         restricted_attn_mask: torch.Tensor | None = None,
-        # unrestricted_attn_mask: torch.Tensor | None = None,
     ):
         self.regional_text_conditioning = regional_text_conditioning
         self.restricted_attn_mask = restricted_attn_mask
-        # self.unrestricted_attn_mask = unrestricted_attn_mask

     def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
         order = [self.restricted_attn_mask, None]
@@ -45,69 +43,67 @@ def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], i
         attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
             regional_text_conditioning, img_seq_len
         )
-        # attn_mask_with_unrestricted_img_self_attn = cls._prepare_unrestricted_attn_mask(
-        #     regional_text_conditioning, img_seq_len
-        # )
         return cls(
             regional_text_conditioning=regional_text_conditioning,
             restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
-            # unrestricted_attn_mask=attn_mask_with_unrestricted_img_self_attn,
         )

-    @classmethod
-    def _prepare_unrestricted_attn_mask(
-        cls,
-        regional_text_conditioning: FluxRegionalTextConditioning,
-        img_seq_len: int,
-    ) -> torch.Tensor:
-        """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
-        - img self-attention is not masked.
-        - img regions attend to both txt within their own region and to global prompts.
-        """
-        device = TorchDevice.choose_torch_device()
-
-        # Infer txt_seq_len from the t5_embeddings tensor.
-        txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
-
-        # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
-        # Concatenation happens in the following order: [txt_seq, img_seq].
-        # There are 4 portions of the attention mask to consider as we prepare it:
-        # 1. txt attends to itself
-        # 2. txt attends to corresponding regional img
-        # 3. regional img attends to corresponding txt
-        # 4. regional img attends to itself
-
-        # Initialize empty attention mask.
-        regional_attention_mask = torch.zeros(
-            (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
-        )
-
-        for image_mask, t5_embedding_range in zip(
-            regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
-        ):
-            # 1. txt attends to itself
-            regional_attention_mask[
-                t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
-            ] = 1.0
-
-            # 2. txt attends to corresponding regional img
-            # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
-            fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
-            regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
-
-            # 3. regional img attends to corresponding txt
-            # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
-            fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
-            regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
-
-            # 4. regional img attends to itself
-            # Allow unrestricted img self attention.
-            regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
-
-        # Convert attention mask to boolean.
-        regional_attention_mask = regional_attention_mask > 0.5
-
-        return regional_attention_mask
+    # Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
+    #
+    # @classmethod
+    # def _prepare_unrestricted_attn_mask(
+    #     cls,
+    #     regional_text_conditioning: FluxRegionalTextConditioning,
+    #     img_seq_len: int,
+    # ) -> torch.Tensor:
+    #     """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
+    #     - img self-attention is not masked.
+    #     - img regions attend to both txt within their own region and to global prompts.
+    #     """
+    #     device = TorchDevice.choose_torch_device()
+
+    #     # Infer txt_seq_len from the t5_embeddings tensor.
+    #     txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
+
+    #     # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
+    #     # Concatenation happens in the following order: [txt_seq, img_seq].
+    #     # There are 4 portions of the attention mask to consider as we prepare it:
+    #     # 1. txt attends to itself
+    #     # 2. txt attends to corresponding regional img
+    #     # 3. regional img attends to corresponding txt
+    #     # 4. regional img attends to itself
+
+    #     # Initialize empty attention mask.
+    #     regional_attention_mask = torch.zeros(
+    #         (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
+    #     )
+
+    #     for image_mask, t5_embedding_range in zip(
+    #         regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
+    #     ):
+    #         # 1. txt attends to itself
+    #         regional_attention_mask[
+    #             t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
+    #         ] = 1.0
+
+    #         # 2. txt attends to corresponding regional img
+    #         # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
+    #         fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
+    #         regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
+
+    #         # 3. regional img attends to corresponding txt
+    #         # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
+    #         fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
+    #         regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
+
+    #         # 4. regional img attends to itself
+    #         # Allow unrestricted img self attention.
+    #         regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
+
+    #     # Convert attention mask to boolean.
+    #     regional_attention_mask = regional_attention_mask > 0.5
+
+    #     return regional_attention_mask

     @classmethod
     def _prepare_restricted_attn_mask(