small convenience for setting which layers get full attention
lucidrains committed Oct 12, 2023
1 parent 9e43418 commit 7558d3f
Showing 2 changed files with 4 additions and 3 deletions.
denoising_diffusion_pytorch/denoising_diffusion_pytorch.py (5 changes: 3 additions & 2 deletions)
@@ -286,7 +286,7 @@ def __init__(
         sinusoidal_pos_emb_theta = 10000,
         attn_dim_head = 32,
         attn_heads = 4,
-        full_attn = None,#default(F, F, F, T)
+        full_attn = None, # defaults to full attention only for inner most layer
         flash_attn = False
     ):
         super().__init__()
@@ -328,7 +328,8 @@ def __init__(
        # attention

        if not full_attn:
-            full_attn = tuple([False] * (len(dim_mults)-1) + [True])
+            full_attn = (*((False,) * (len(dim_mults) - 1)), True)
+
        num_stages = len(dim_mults)
        full_attn = cast_tuple(full_attn, num_stages)
        attn_heads = cast_tuple(attn_heads, num_stages)
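For context, the rewritten default builds one boolean flag per U-Net stage and enables full attention only at the innermost (deepest) stage; the outer stages keep the cheaper attention variant. A minimal sketch of how the default resolves, assuming dim_mults = (1, 2, 4, 8); the standalone resolve_full_attn helper is illustrative only, not part of the library, and only the tuple expression mirrors the diff:

# illustrative sketch of the default resolution added in this commit
def resolve_full_attn(full_attn, dim_mults):
    if not full_attn:
        # one flag per stage; only the innermost stage gets full attention
        full_attn = (*((False,) * (len(dim_mults) - 1)), True)
    return full_attn

print(resolve_full_attn(None, (1, 2, 4, 8)))                         # (False, False, False, True)
print(resolve_full_attn((True, False, False, True), (1, 2, 4, 8)))   # explicit setting passes through unchanged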
denoising_diffusion_pytorch/version.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-__version__ = '1.9.0'
+__version__ = '1.9.1'

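As a usage sketch of the convenience this commit adds (keyword names taken from the diff; the dim value, dim_mults, and import path are assumptions based on the package layout):

from denoising_diffusion_pytorch import Unet

# leaving full_attn unset now picks full attention for the innermost stage only
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    full_attn = None,        # resolves to (False, False, False, True)
    flash_attn = False
)

# or set the per-stage flags explicitly, e.g. full attention on the two deepest stages
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    full_attn = (False, False, True, True)
)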