Skip to content

Commit

Permalink
Add fx pass to remove sdpa composite zero mask input (#62)
Browse files Browse the repository at this point in the history
* init

* add copyright

* Add test
  • Loading branch information
chunnienc authored Jun 20, 2024
1 parent 1a4fe12 commit 0a6c9b3
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ai_edge_torch/convert/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
from ai_edge_torch.convert.fx_passes import run_passes
from ai_edge_torch.generative.fx_passes import run_generative_passes
from ai_edge_torch.quantize import quant_config as qcfg

os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
Expand All @@ -38,6 +39,7 @@
def _run_convert_passes(
exported_program: ExportedProgram,
) -> ExportedProgram:
exported_program = run_generative_passes(exported_program)
return run_passes(
exported_program,
[
Expand Down
31 changes: 31 additions & 0 deletions ai_edge_torch/generative/fx_passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from ai_edge_torch.convert.fx_passes import CanonicalizePass
from ai_edge_torch.convert.fx_passes import run_passes
from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass # NOQA


def run_generative_passes(
    exported_program: torch.export.ExportedProgram,
) -> torch.export.ExportedProgram:
  """Apply the generative-model FX passes to an exported program.

  The program is handed to `run_passes` with, in order,
  `RemoveSDPACompositeZeroMaskPass` and `CanonicalizePass`.

  Args:
    exported_program: The exported program to transform.

  Returns:
    The value returned by `run_passes` for the pass list above.
  """
  generative_passes = [
      RemoveSDPACompositeZeroMaskPass(),
      CanonicalizePass(),
  ]
  return run_passes(exported_program, generative_passes)
47 changes: 47 additions & 0 deletions ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA


class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
  """Unmarks the mask input of SDPA composites when the mask is a zeros tensor.

  The pass scans for `xla.mark_tensor` calls that tag input #3 (the mask) of
  an `odml.scaled_dot_product_attention` composite. When the tagged tensor is
  produced by `aten.zeros.default`, the marker call is turned into an identity
  so the mask is no longer marked as a composite input.
  """

  def is_zero_tensor_node(self, node: torch.fx.Node):
    """Return whether `node` targets `torch.ops.aten.zeros.default`."""
    zeros_op = torch.ops.aten.zeros.default
    return node.target == zeros_op

  def call(self, exported_program: torch.export.ExportedProgram):
    """Rewrite zero-mask SDPA markers in place and return the pass result.

    Args:
      exported_program: Program whose graph module is scanned and mutated.

    Returns:
      An `ExportedProgramPassResult` wrapping the same program; the second
      argument is always `True`, whether or not a node was rewritten.
    """
    graph_module = exported_program.graph_module

    for node in graph_module.graph.nodes:
      # Only `xla.mark_tensor` function calls are of interest.
      if node.op != "call_function":
        continue
      if node.target != torch.ops.xla.mark_tensor.default:
        continue

      # Leading mark_tensor args: (source, name, io_position, id, is_input).
      # The id is not needed here.
      source, name, io_position, _, is_input = node.args[:5]

      # Composite info:
      # - name: odml.scaled_dot_product_attention
      # - inputs: q, k, v, mask
      marks_sdpa_mask_input = (
          name == "odml.scaled_dot_product_attention"
          and is_input
          and io_position == 3
      )
      if marks_sdpa_mask_input and self.is_zero_tensor_node(source):
        # Remove the mark_tensor call on the mask input by replacing the
        # target with an identity function.
        node.target = lambda *args, **kwargs: args[0]

    graph_module.graph.lint()
    graph_module.recompile()
    return ExportedProgramPassResult(exported_program, True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
from typing import Callable, Union
import unittest

import torch
import torch_xla

from ai_edge_torch.convert.fx_passes import CanonicalizePass
from ai_edge_torch.convert.fx_passes import run_passes
from ai_edge_torch.generative.fx_passes import RemoveSDPACompositeZeroMaskPass
from ai_edge_torch.generative.layers.attention import SelfAttention
import ai_edge_torch.generative.layers.model_config as layers_cfg
import ai_edge_torch.generative.layers.unet.builder as unet_builder
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg


def _export_to_stablehlo(func: Union[torch.nn.Module, Callable], export_args):
  """Export `func`, run the zero-mask removal passes, and return StableHLO text.

  Args:
    func: A `torch.nn.Module`, or a plain callable that gets wrapped in a
      throwaway module (put in eval mode) before exporting.
    export_args: Positional example inputs passed to `torch.export.export`.

  Returns:
    The StableHLO text of the program after `RemoveSDPACompositeZeroMaskPass`
    and `CanonicalizePass` have been run on it.
  """
  if isinstance(func, torch.nn.Module):
    module = func
  else:

    class TestModule(torch.nn.Module):
      """Adapter that forwards every call to the wrapped callable."""

      def forward(self, *args, **kwargs):
        return func(*args, **kwargs)

    module = TestModule().eval()

  passes_under_test = [
      RemoveSDPACompositeZeroMaskPass(),
      CanonicalizePass(),
  ]
  exported_program = torch.export.export(module, export_args)
  exported_program = run_passes(exported_program, passes_under_test)

  stablehlo_bundle = torch_xla.stablehlo.exported_program_to_stablehlo(
      exported_program
  )
  return stablehlo_bundle.get_stablehlo_text()


class TestRemoveSDPAZeroMaskPass(unittest.TestCase):
  """Tests for RemoveSDPACompositeZeroMaskPass on a self-attention block."""

  def test_self_attention_no_zero_mask_composite_input(self):
    """The SDPA composite in the exported StableHLO has exactly 3 operands."""

    class SampleSdpaBlock(torch.nn.Module):
      """Sample attention block with SDPA"""

      def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
        super().__init__()
        self.config = config
        self.attention = SelfAttention(
            config.attention_batch_size,
            config.dim,
            config.attention_config,
            0,
            enable_hlfb=config.enable_hlfb,
        )

      def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        B, C, H, W = input_tensor.shape
        # Flatten the spatial dims and move channels last for attention:
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C).
        x = input_tensor.view(B, C, H * W)
        x = x.transpose(-1, -2)
        # x = x.contiguous() # Prevent BATCH_MATMUL op in converted tflite.
        x = self.attention(x)
        # Restore the original (B, C, H, W) layout.
        x = x.transpose(-1, -2)
        x = x.view(B, C, H, W)
        return x

    def get_model_config() -> unet_cfg.AttentionBlock2DConfig:
      """Get configs for the Decoder of Stable Diffusion v1.5"""
      block_out_channels = [128, 256, 512, 512]

      norm_config = layers_cfg.NormalizationConfig(
          layers_cfg.NormalizationType.GROUP_NORM, group_num=32
      )

      return unet_cfg.AttentionBlock2DConfig(
          dim=block_out_channels[-1],
          normalization_config=norm_config,
          attention_config=layers_cfg.AttentionConfig(
              num_heads=1,
              num_query_groups=1,
              qkv_use_bias=True,
              output_proj_use_bias=True,
              enable_kv_cache=False,
              qkv_transpose_before_split=True,
              rotary_percentage=0.0,
          ),
      )

    stablehlo = _export_to_stablehlo(
        SampleSdpaBlock(get_model_config()).eval(),
        (torch.rand(1, 512, 64, 64),),
    )
    # Raw string: `\.` and `\d` are regex escapes, not Python string escapes.
    # The pattern requires the operand list to end after the third operand,
    # i.e. the mask is not passed to the composite. assertRegex includes the
    # searched text in its failure message.
    self.assertRegex(
        stablehlo,
        r'stablehlo\.composite "odml\.scaled_dot_product_attention" %\d+, %\d+, %\d+ {',
    )


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()

0 comments on commit 0a6c9b3

Please sign in to comment.