Skip to content

Commit

Permalink
Temporarily disable HLFB for stable diffusion (#104)
Browse files Browse the repository at this point in the history
* Temporarily disable SDPA (scaled dot-product attention HLFB)

* Update
  • Loading branch information
yichunk authored Jul 31, 2024
1 parent e79fdc0 commit 4c40530
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
qkv_fused_interleaved=False,
rotary_percentage=0.0,
),
enable_hlfb=False,
)

mid_block_config = unet_cfg.MidBlock2DConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
attention_batch_size=config.transformer_batch_size,
normalization_config=config.transformer_norm_config,
attention_config=attention_config,
enable_hlfb=False,
),
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
query_dim=output_channel,
cross_dim=config.transformer_cross_attention_dim,
attention_batch_size=config.transformer_batch_size,
normalization_config=config.transformer_norm_config,
attention_config=attention_config,
enable_hlfb=False,
),
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
Expand Down Expand Up @@ -354,13 +356,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
attention_batch_size=config.transformer_batch_size,
normalization_config=config.transformer_norm_config,
attention_config=attention_config,
enable_hlfb=False,
),
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
query_dim=mid_block_channels,
cross_dim=config.transformer_cross_attention_dim,
attention_batch_size=config.transformer_batch_size,
normalization_config=config.transformer_norm_config,
attention_config=attention_config,
enable_hlfb=False,
),
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
Expand Down Expand Up @@ -415,13 +419,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
attention_batch_size=config.transformer_batch_size,
normalization_config=config.transformer_norm_config,
attention_config=attention_config,
enable_hlfb=False,
),
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
query_dim=output_channel,
cross_dim=config.transformer_cross_attention_dim,
attention_batch_size=config.transformer_batch_size,
normalization_config=config.transformer_norm_config,
attention_config=attention_config,
enable_hlfb=False,
),
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
Expand Down

0 comments on commit 4c40530

Please sign in to comment.