Short-circuit if there are no region masks in FLUX and don't apply attention masking.
RyanJDick committed Nov 27, 2024
1 parent 6565cea commit 64364e7
Showing 1 changed file with 18 additions and 23 deletions.
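In practice, the change means `_prepare_restricted_attn_mask` now returns `None` when no regional prompt masks are present, and the caller skips attention masking entirely in that case. The following is a minimal, self-contained sketch of that short-circuit pattern; `prepare_attn_mask` and the toy mask construction are invented for illustration and are not the InvokeAI implementation.

# Minimal sketch of the short-circuit pattern (illustrative only, not the InvokeAI code).
import torch
import torch.nn.functional as F

def prepare_attn_mask(image_masks: list[torch.Tensor | None], seq_len: int) -> torch.Tensor | None:
    region_masks = [m for m in image_masks if m is not None]
    if not region_masks:
        # No region masks: short-circuit. Returning None means "apply no attention masking".
        return None
    # Toy stand-in for the real restricted-mask construction: allow attention within each region.
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for m in region_masks:
        m = m.view(seq_len, 1).bool()
        mask |= m & m.T
    return mask

q = k = v = torch.randn(1, 1, 4, 8)
attn_mask = prepare_attn_mask([None], seq_len=4)  # no region masks -> None
# A None mask leaves attention fully unrestricted.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)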
invokeai/backend/flux/extensions/regional_prompting_extension.py
@@ -110,12 +110,25 @@ def _prepare_restricted_attn_mask(
         cls,
         regional_text_conditioning: FluxRegionalTextConditioning,
         img_seq_len: int,
-    ) -> torch.Tensor:
+    ) -> torch.Tensor | None:
         """Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
         - img self-attention is only allowed within regions.
         - img regions only attend to txt within their own region, not to global prompts.
         """
+        # Identify the background region, i.e. the region that is not covered by any region masks.
+        background_region_mask: None | torch.Tensor = None
+        for image_mask in regional_text_conditioning.image_masks:
+            if image_mask is not None:
+                if background_region_mask is None:
+                    background_region_mask = torch.ones_like(image_mask)
+                background_region_mask *= 1 - image_mask
+
+        if background_region_mask is None:
+            # There are no region masks, so short-circuit and return None.
+            # TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this
+            # is a rare use case and would make the logic here significantly more complicated.
+            return None
+
         device = TorchDevice.choose_torch_device()

         # Infer txt_seq_len from the t5_embeddings tensor.
@@ -134,14 +147,6 @@ def _prepare_restricted_attn_mask(
             (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
         )

-        # Identify background region. I.e. the region that is not covered by any region masks.
-        background_region_mask: None | torch.Tensor = None
-        for image_mask in regional_text_conditioning.image_masks:
-            if image_mask is not None:
-                if background_region_mask is None:
-                    background_region_mask = torch.ones_like(image_mask)
-                background_region_mask *= 1 - image_mask
-
         for image_mask, t5_embedding_range in zip(
             regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
         ):
@@ -167,10 +172,6 @@ def _prepare_restricted_attn_mask(
                 image_mask = image_mask.view(img_seq_len, 1)
                 regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
             else:
-                if background_region_mask is None:
-                    # There are no region masks, so we don't need to do anything here - this case is handled below.
-                    continue
-
                 # We don't allow attention between non-background image regions and global prompts. This helps to ensure
                 # that regions focus on their local prompts. We do, however, allow attention between background regions
                 # and global prompts. If we didn't do this, then the background regions would not attend to any txt
@@ -188,15 +189,9 @@ def _prepare_restricted_attn_mask(
                     background_region_mask.view(img_seq_len, 1)
                 )

-        # Handle image background regions.
-        if background_region_mask is None:
-            # There are no region masks, so allow unrestricted img-img attention, and unrestricted img-txt attention.
-            regional_attention_mask[txt_seq_len:, :] = 1.0
-            regional_attention_mask[:, txt_seq_len:] = 1.0
-        else:
-            # Allow background regions to attend to themselves.
-            regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
-            regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
+        # Allow background regions to attend to themselves.
+        regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
+        regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)

         # Convert attention mask to boolean.
         regional_attention_mask = regional_attention_mask > 0.5
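For context on the mask math the diff manipulates, here is a toy, hedged illustration (sizes and names invented; not the InvokeAI code) of the two operations involved: accumulating the background region as the complement of all region masks, and granting within-region img-img attention via an outer product of each flattened mask, with the background rows and columns added at the end as in the new code above.

# Toy illustration (invented sizes/names, not the InvokeAI code) of the mask operations above.
import torch

img_seq_len = 6
# Two flattened region masks over the image token sequence (1.0 = token belongs to the region).
region_a = torch.tensor([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
region_b = torch.tensor([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])

# Background = tokens covered by no region mask (same accumulation as in the diff).
background = torch.ones(img_seq_len)
for m in (region_a, region_b):
    background *= 1 - m
# background is now tensor([0., 0., 0., 0., 1., 1.])

# Within-region img-img attention: the outer product m @ m.T is 1 exactly where both the
# query token and the key token belong to the same region.
attn = torch.zeros(img_seq_len, img_seq_len)
for m in (region_a, region_b):
    m = m.view(img_seq_len, 1)
    attn += m @ m.T
# Mirror the final two added lines of the diff: add background rows and columns.
attn += background.view(img_seq_len, 1)
attn += background.view(1, img_seq_len)
attn_bool = attn > 0.5  # convert to a boolean attention mask, as the real code does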
