Skip to content

Commit

Permalink
Rename is_unspecialized to pass_arg_as_tensor, add comment (#125496)
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: #125496
Approved by: https://github.com/lezcano
ghstack dependencies: #125395, #125419, #125483, #125494
  • Loading branch information
ezyang authored and pytorchmergebot committed May 5, 2024
1 parent 12da7ee commit 650a248
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
2 changes: 1 addition & 1 deletion torch/_dynamo/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def make_call_generated_code(self, fn_name: str) -> None:

graphargs = self.tx.output.graphargs
for arg in graphargs:
if arg.is_unspecialized:
if arg.pass_arg_as_tensor:
self.extend_output(
[
self.create_load_python_module(torch, True),
Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def bind_symint(s, prop):
proxy.node.meta["grapharg"] = GraphArg(
prop,
s,
is_unspecialized=False,
pass_arg_as_tensor=False,
fake_tensor=None,
is_tensor=False,
)
Expand Down
34 changes: 27 additions & 7 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,23 @@ class GraphArg:
# thing to do. Probably should have example (which stores an int) and
# fake_example
_example: Union[TensorWeakRef, torch.SymInt]
is_unspecialized: bool
# When True, this indicates that this GraphArg is a Python quantity (e.g.,
# a float or int) which we pass to the FX graph as a Tensor. This
# controls how we codegen calls into the Dynamo graph: we will call
# torch.as_tensor on the quantity before passing it in.
#
# Note that we typically do not pass dynamic integers as tensors, because
# they will most frequently just be used for size computation. But this
# is a policy decision that we can change our mind on; in particular, when
# an int comes from a random number generator (e.g., random.randint), we
# DO pass it as a tensor.
#
# It's also worth noting that our current tracing rules for
# pass_arg_as_tensor are subtly broken: we just pun the variable as a
# 0d scalar Tensor and pray that the semantics are the same. Which they
# often are, but not necessarily. ezyang (May 2024) plans to fix this
# soon.
pass_arg_as_tensor: bool
fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
# UnspecializedPythonVariable often masquerades as a tensor.
# We MUST NOT generate shape guard code
Expand Down Expand Up @@ -233,7 +249,7 @@ def __init__(self):
super().__init__(
source=None,
_example=BackwardState(),
is_unspecialized=False,
pass_arg_as_tensor=False,
fake_tensor=None,
is_tensor=False,
)
Expand Down Expand Up @@ -955,7 +971,11 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
install_guard(*guards, skip=1)

grapharg = GraphArg(
source, value, is_unspecialized=False, fake_tensor=None, is_tensor=False
source,
value,
pass_arg_as_tensor=False,
fake_tensor=None,
is_tensor=False,
)
tensor_list_proxy.node.meta["grapharg"] = grapharg

Expand Down Expand Up @@ -1288,12 +1308,12 @@ def wrap_numpy_ndarray(self, value):
self.tx.output.input_source_to_var[source] = numpy_ndarray_variable
example_value = numpy_ndarray_variable.proxy.node.meta["example_value"]

# is_unspecialized should be true because we are wrapping a np.ndarray as argument input, and it needs to be
# pass_arg_as_tensor should be true because we are wrapping a np.ndarray as argument input, and it needs to be
# converted to a tensor.
grapharg = GraphArg(
source,
tensor_value,
is_unspecialized=True,
pass_arg_as_tensor=True,
fake_tensor=example_value,
is_tensor=True,
example_strong_ref=tensor_value,
Expand Down Expand Up @@ -1404,7 +1424,7 @@ def wrap_symint(self, value):
proxy.node.meta["grapharg"] = GraphArg(
self.get_source(),
wrapped_value,
is_unspecialized=False,
pass_arg_as_tensor=False,
fake_tensor=None,
is_tensor=False,
example_strong_ref=wrapped_value,
Expand Down Expand Up @@ -1459,7 +1479,7 @@ def wrap_unspecialized_primitive(self, value):
proxy.node.meta["grapharg"] = GraphArg(
self.get_source(),
wrapped_value,
is_unspecialized=True,
pass_arg_as_tensor=True,
fake_tensor=fake_tensor_value,
is_tensor=False,
example_strong_ref=wrapped_value,
Expand Down

0 comments on commit 650a248

Please sign in to comment.