
Commit

Comment unused _prepare_unrestricted_attn_mask(...) for future reference.
RyanJDick committed Nov 27, 2024
1 parent 3ebd8d6 commit 6565cea
Showing 1 changed file with 57 additions and 61 deletions.
invokeai/backend/flux/extensions/regional_prompting_extension.py (118 changes: 57 additions & 61 deletions)
@@ -12,18 +12,16 @@
class RegionalPromptingExtension:
"""A class for managing regional prompting with FLUX.
Implementation inspired by: https://arxiv.org/pdf/2411.02395
This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
"""

def __init__(
self,
regional_text_conditioning: FluxRegionalTextConditioning,
restricted_attn_mask: torch.Tensor | None = None,
# unrestricted_attn_mask: torch.Tensor | None = None,
):
self.regional_text_conditioning = regional_text_conditioning
self.restricted_attn_mask = restricted_attn_mask
# self.unrestricted_attn_mask = unrestricted_attn_mask

def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
order = [self.restricted_attn_mask, None]
@@ -45,69 +43,67 @@ def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], i
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
regional_text_conditioning, img_seq_len
)
# attn_mask_with_unrestricted_img_self_attn = cls._prepare_unrestricted_attn_mask(
# regional_text_conditioning, img_seq_len
# )
return cls(
regional_text_conditioning=regional_text_conditioning,
restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
# unrestricted_attn_mask=attn_mask_with_unrestricted_img_self_attn,
)

@classmethod
def _prepare_unrestricted_attn_mask(
cls,
regional_text_conditioning: FluxRegionalTextConditioning,
img_seq_len: int,
) -> torch.Tensor:
"""Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
- img self-attention is not masked.
- img regions attend to both txt within their own region and to global prompts.
"""
device = TorchDevice.choose_torch_device()

# Infer txt_seq_len from the t5_embeddings tensor.
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]

# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# Concatenation happens in the following order: [txt_seq, img_seq].
# There are 4 portions of the attention mask to consider as we prepare it:
# 1. txt attends to itself
# 2. txt attends to corresponding regional img
# 3. regional img attends to corresponding txt
# 4. regional img attends to itself

# Initialize empty attention mask.
regional_attention_mask = torch.zeros(
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
)

for image_mask, t5_embedding_range in zip(
regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
):
# 1. txt attends to itself
regional_attention_mask[
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
] = 1.0

# 2. txt attends to corresponding regional img
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value

# 3. regional img attends to corresponding txt
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value

# 4. regional img attends to itself
# Allow unrestricted img self attention.
regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0

# Convert attention mask to boolean.
regional_attention_mask = regional_attention_mask > 0.5

return regional_attention_mask
# Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
#
# @classmethod
# def _prepare_unrestricted_attn_mask(
# cls,
# regional_text_conditioning: FluxRegionalTextConditioning,
# img_seq_len: int,
# ) -> torch.Tensor:
# """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
# - img self-attention is not masked.
# - img regions attend to both txt within their own region and to global prompts.
# """
# device = TorchDevice.choose_torch_device()

# # Infer txt_seq_len from the t5_embeddings tensor.
# txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]

# # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# # Concatenation happens in the following order: [txt_seq, img_seq].
# # There are 4 portions of the attention mask to consider as we prepare it:
# # 1. txt attends to itself
# # 2. txt attends to corresponding regional img
# # 3. regional img attends to corresponding txt
# # 4. regional img attends to itself

# # Initialize empty attention mask.
# regional_attention_mask = torch.zeros(
# (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
# )

# for image_mask, t5_embedding_range in zip(
# regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
# ):
# # 1. txt attends to itself
# regional_attention_mask[
# t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
# ] = 1.0

# # 2. txt attends to corresponding regional img
# # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
# fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
# regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value

# # 3. regional img attends to corresponding txt
# # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
# fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
# regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value

# # 4. regional img attends to itself
# # Allow unrestricted img self attention.
# regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0

# # Convert attention mask to boolean.
# regional_attention_mask = regional_attention_mask > 0.5

# return regional_attention_mask
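For reference (not part of the committed file): a minimal standalone sketch of the mask layout that the four numbered steps above produce, using made-up toy sizes and a single hypothetical region, with plain torch in place of the InvokeAI helpers (FluxRegionalTextConditioning, TorchDevice).

import torch

# Toy dimensions, chosen only to make the structure visible.
txt_seq_len = 4  # txt tokens; the regional prompt occupies tokens [0, 2)
img_seq_len = 6  # flattened image tokens; the region covers the first 3 of them
total_len = txt_seq_len + img_seq_len

image_mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])  # per-image-token region mask
txt_start, txt_end = 0, 2  # stand-in for this region's t5_embedding_range

# The real code allocates this in float16 on the chosen device; the default dtype is fine for a toy.
mask = torch.zeros((total_len, total_len))
# 1. txt attends to itself
mask[txt_start:txt_end, txt_start:txt_end] = 1.0
# 2. txt attends to corresponding regional img
mask[txt_start:txt_end, txt_seq_len:] = image_mask.view(1, img_seq_len)
# 3. regional img attends to corresponding txt
mask[txt_seq_len:, txt_start:txt_end] = image_mask.view(img_seq_len, 1)
# 4. img self-attention is left unrestricted
mask[txt_seq_len:, txt_seq_len:] = 1.0

bool_mask = mask > 0.5  # final boolean attention mask, shape (10, 10)

Printing bool_mask.int() shows the (txt + img) x (txt + img) layout: the region's txt rows attend to themselves and to the first three img columns, those img rows attend back to the region's txt columns, and the img-img quadrant is all ones because img self-attention is unrestricted. Txt tokens outside the region stay fully masked in this toy; in the real method, other regions and global prompts would fill those rows.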

@classmethod
def _prepare_restricted_attn_mask(
