[export] use tree_map for _flatten_dynamic_shapes #125415

Closed · wants to merge 1 commit
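This PR replaces the hand-rolled leaf heuristic in _flatten_dynamic_shapes with a _tree_map walk over the combined example arguments, so the dynamic_shapes spec is flattened in lockstep with the inputs. Previously an inner tuple of specs such as (None, None) could itself be treated as a single leaf, yielding fewer specs than there are inputs; now each individual input (tensor or int) is paired with exactly one spec, and make_constraints asserts that the flattened spec count matches the number of non-lifted placeholders.

As a rough illustration (not code from this PR, just standard pytree flattening), the nested inputs from the new test flatten to seven leaves, which is the count the spec must now match:

import torch
from torch.utils import _pytree as pytree

# Same nesting as the new test: a pair of ints, a pair of tensors, a triple of tensors.
inputs = (
    (1, 2),
    (torch.randn(4, 4), torch.randn(4, 4)),
    (torch.randn(4, 4), torch.randn(4, 4), torch.randn(4, 4)),
)
flat_inputs, _ = pytree.tree_flatten(inputs)
assert len(flat_inputs) == 7  # 2 ints + 5 tensors, i.e. 7 individual inputs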
29 changes: 29 additions & 0 deletions test/export/test_export.py
@@ -4561,6 +4561,35 @@ def forward(self, x, y, div="floor"):
        self.assertEqual(div_spec.arg.name, "div")
        self.assertEqual(div_spec.arg.value, "floor")

    def test_nested_dynamic_shapes_spec(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                (a0, a1), (b0, b1), (c0, c1, c2) = x
                return a0 + a1 + b0 + b1 + c0 + c1 + c2

        f = Foo()
        inputs = (
            (1, 2),
            (
                torch.randn(4, 4),
                torch.randn(4, 4),
            ),
            (
                torch.randn(4, 4),
                torch.randn(4, 4),
                torch.randn(4, 4),
            ),
        )
        # make sure this gets parsed correctly as 7 individual inputs, not 3 tensors
        dynamic_shapes = {
            "x": (
                (None, None),
                (None, None),
                (None, None, None),
            )
        }
        export(f, (inputs,), dynamic_shapes=dynamic_shapes)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):
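The None entries above mean each of the seven inputs is treated as fully static; the same nested structure also accepts per-dimension specs. An illustrative variant (not part of this PR), assuming the usual torch.export.Dim API and reusing f and inputs from the test:

from torch.export import Dim, export

batch = Dim("batch")
dynamic_shapes = {
    "x": (
        (None, None),  # the two ints carry no shape to relax
        ({0: batch}, {0: batch}),  # dim 0 of each tensor is dynamic and shared
        ({0: batch}, {0: batch}, {0: batch}),
    )
}
export(f, (inputs,), dynamic_shapes=dynamic_shapes)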
3 changes: 3 additions & 0 deletions torch/_export/__init__.py
@@ -43,6 +43,7 @@
    Constraint,
    dims,
    dynamic_dim,
    _combine_args,
)
from torch.export.exported_program import (
    _disable_prexisiting_fake_mode,
@@ -175,9 +176,11 @@ def capture_pre_autograd_graph(
    _restore_state_dict(f, m)

    flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
    combined_args = _combine_args(f, args, kwargs)
    range_constraints = make_constraints(
        fake_mode,
        m,
        combined_args,
        dynamic_shapes,
        0,
    )
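For context on the new combined_args argument: _combine_args (imported above) binds the example args/kwargs to the parameter names of the module's forward, and make_constraints flattens the dynamic_shapes spec against that mapping. Roughly, and only as a sketch of the result's shape rather than the real helper in torch/export/dynamic_shapes.py:

import inspect

def _combine_args_sketch(f, args, kwargs):
    # Bind positional and keyword example inputs to forward's signature so the
    # result is an ordered mapping {parameter name: example value}.
    sig = inspect.signature(f.forward)
    bound = sig.bind(*args, **(kwargs or {}))
    return dict(bound.arguments)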
45 changes: 23 additions & 22 deletions torch/_export/non_strict_utils.py
@@ -15,7 +15,7 @@
from torch._guards import Source
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import _Dim
from torch.export.dynamic_shapes import _tree_map
from torch.export.graph_signature import CustomObjArgument
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
@@ -30,7 +30,6 @@
    KeyPath,
    MappingKey,
    SequenceKey,
    tree_flatten,
    tree_map_with_path,
)

@@ -180,25 +179,17 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes):


def _flatten_dynamic_shapes(
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]]
):
    def _is_dynamic_shape_leaf(x):
        if isinstance(x, dict):
            x = list(x.values())
        return x is None or all(isinstance(y, (_Dim, int)) or y is None for y in x)

    if isinstance(dynamic_shapes, (list, tuple)):
        flat_dynamic_shapes = []
        for item in dynamic_shapes:
            flat_shapes, _ = tree_flatten(
                dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
            )
            flat_dynamic_shapes += flat_shapes
    else:
        flat_dynamic_shapes, _ = tree_flatten(
            dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
        )
    return flat_dynamic_shapes
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
) -> List[Any]:
    flat_shapes = []

    def _tree_map_helper(t, shape):
        nonlocal flat_shapes
        flat_shapes.append(shape)

    _tree_map(_tree_map_helper, combined_args, dynamic_shapes)
    return flat_shapes


def produce_guards_and_solve_constraints(
@@ -260,6 +251,7 @@ def produce_guards_and_solve_constraints(
def make_constraints(
    fake_mode: FakeTensorMode,
    gm: torch.fx.GraphModule,
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    num_lifted_inputs: int,
):
@@ -280,7 +272,16 @@ def make_constraints(
    if not dynamic_shapes:
        return range_constraints

    flat_dynamic_shapes = _flatten_dynamic_shapes(dynamic_shapes)
    # get individual dynamic shapes spec for each input
    if not isinstance(dynamic_shapes, dict):
        assert isinstance(dynamic_shapes, (tuple, list))
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]
    flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)

    # check number of shapes vs. number of inputs
    num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True)
    assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs

    input_dims = defaultdict(list)
    free_symbols = set()
    for input_index, node in enumerate(gm.graph.nodes):
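The behavior of the rewritten _flatten_dynamic_shapes can be approximated with plain pytree utilities when the spec mirrors the nesting of the inputs exactly (e.g. one None per input, as in the new test). A minimal sketch under that assumption; the real helper uses _tree_map so that leafness is decided by combined_args rather than by the spec:

from typing import Any, Dict, List
from torch.utils import _pytree as pytree

def _flatten_dynamic_shapes_sketch(
    combined_args: Dict[str, Any], dynamic_shapes: Any
) -> List[Any]:
    # Walk the args tree and collect the spec sitting at the same position in
    # the dynamic_shapes tree; with a mirroring spec, flattening both sides
    # lines the leaves up one-to-one.
    flat_args, _ = pytree.tree_flatten(combined_args)
    flat_shapes, _ = pytree.tree_flatten(dynamic_shapes)
    assert len(flat_shapes) == len(flat_args)
    return flat_shapes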
5 changes: 5 additions & 0 deletions torch/export/_trace.py
@@ -36,6 +36,7 @@
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._utils_internal import log_export_usage
from torch.export.dynamic_shapes import _combine_args
from torch.export.exported_program import OutputKind
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.symbolic_shapes import (
@@ -1061,9 +1062,11 @@ def forward(self, *args, **kwargs):
    except (ConstraintViolationError, ValueRangeError) as e:
        raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200

    combined_args = _combine_args(mod, args, kwargs)
    range_constraints = make_constraints(
        fake_mode,
        ep_non_strict.gm,
        combined_args,
        dynamic_shapes,
        num_lifted,
    )
@@ -1269,9 +1272,11 @@ def forward(self, *args, **kwargs):
        ),
        len(export_graph_signature.input_specs),
    )
    combined_args = _combine_args(mod, args, kwargs)
    range_constraints = make_constraints(
        dynamo_fake_mode,
        gm,
        combined_args,
        dynamic_shapes,
        num_lifted,
    )