Split wrap_symint out of wrap_unspecialized_primitive (#125483)
While there are some similarities, they are also quite different (one
handles NumPy numbers while the other handles ints). I am also going to
add a wrap_symfloat soon, which will diverge even further in behavior.
So split these out for clarity.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: #125483
Approved by: https://github.com/lezcano
ghstack dependencies: #125395, #125419
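
To make the intent of the split concrete, here is a minimal, self-contained sketch of the new routing: plain Python ints go to wrap_symint, while NumPy-style scalars keep going through wrap_unspecialized_primitive. This is not the actual Dynamo code; the tuple return values and the specialize_int parameter are toy stand-ins for SymNodeVariable / UnspecializedPythonVariable / ConstantVariable and torch._dynamo.config.specialize_int.

# Toy model of the dispatch after this commit; none of this is real Dynamo API.

def wrap_symint(value):
    # Real code allocates a SymInt from the ShapeEnv and wraps it in a
    # SymNodeVariable backed by a graph input.
    assert type(value) is int
    return ("SymNodeVariable", value)

def wrap_unspecialized_primitive(value):
    # Real code wraps the value in a 0-d tensor and builds an
    # UnspecializedPythonVariable (used e.g. for NumPy numbers).
    return ("UnspecializedPythonVariable", value)

def wrap_literal(value, specialize_int=False):
    # Mirrors the wrap_literal hunk below: unspecialized ints go to the
    # new wrap_symint; everything else stays a constant here.
    if not specialize_int and type(value) is int:
        return wrap_symint(value)
    return ("ConstantVariable", value)

print(wrap_literal(3))                       # ('SymNodeVariable', 3)
print(wrap_literal(3, specialize_int=True))  # ('ConstantVariable', 3)
print(wrap_unspecialized_primitive(2.5))     # ('UnspecializedPythonVariable', 2.5)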
ezyang authored and pytorchmergebot committed May 5, 2024
1 parent 10f6735 commit 617e473
Showing 2 changed files with 161 additions and 112 deletions.
270 changes: 158 additions & 112 deletions torch/_dynamo/variables/builder.py
@@ -1083,7 +1083,7 @@ def wrap_literal(self, value):
                 self.install_guards(GuardBuilder.CONSTANT_MATCH)
                 return ConstantVariable.create(value=value, source=self.source)
             else:
-                return self.wrap_unspecialized_primitive(value)
+                return self.wrap_symint(value)
         else:
             self.install_guards(GuardBuilder.CONSTANT_MATCH)
             return ConstantVariable.create(value=value)
@@ -1302,128 +1302,174 @@ def wrap_numpy_ndarray(self, value):

return numpy_ndarray_variable

def wrap_unspecialized_primitive(self, value):
def wrap_symint(self, value):
assert type(value) is int

if self.name in self.tx.output.unspec_variable_map:
return self.tx.output.unspec_variable_map[self.name]

shape_env = self.tx.output.shape_env
if TracingContext.get().force_unspec_int_unbacked_size_like:
wrapped_value = shape_env.create_unbacked_symint()
_constrain_range_for_size(wrapped_value)
self.tx.output.bound_symbols.add(wrapped_value.node.expr)
self.tx.output.tracked_fakes.append(
TrackedFake(wrapped_value, self.source, None)
)

# NB: We do not do float. For motivation, see
# https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
# but the general idea is that we generate kernels that can
# take unspecialized floats and use them in sizevar computation
elif not is_constant_source(self.get_source()):
if torch._dynamo.config.specialize_int:
# If specialize_int is False, also return
# a constant (but this should have been handled
# in the caller, TBH)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value, source=self.source)

name = self.source.name()
if name not in self.tx.output.frame_state:
# Note - this essentially means that if this name gets reused as a tensor,
# it will start fully dynamic. That should always be a safe option, and not awfully inefficient.
# Alternatively, if we want to improve pef here, we can add a third state of unset, but I am not
# sure that is necessary for now.
frame_state_entry = FrameStateSizeEntry(scalar=value, size=None)
else:
frame_state_entry = self.tx.output.frame_state[name]
if frame_state_entry.scalar != value:
log.debug(
"automatic dynamic int %s val %s != %s",
name,
value,
frame_state_entry.scalar,
)
frame_state_entry.scalar = None
self.tx.output.frame_state[name] = frame_state_entry

# TODO: This should be dynamic, as we in general do not
# know if bare integers are actually going to be sizevars
# and it is inappropriate to eagerly duck size them with
# real sizevars
if (
config.automatic_dynamic_shapes and frame_state_entry.scalar is None
) or not config.assume_static_by_default:
dynamic_dim = DimDynamic.DYNAMIC
else: # assume_static_by_default
# TODO: dynamic_dim = DimDynamic.STATIC should work but
# for some reason it doesn't
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)

wrapped_value = shape_env.create_unspecified_symint_and_symbol(
value,
source=self.source,
dynamic_dim=dynamic_dim,
)
self.tx.output.bound_symbols.add(wrapped_value.node.expr)

self.tx.output.tracked_fakes.append(
TrackedFake(wrapped_value, self.source, None)
)
else:
shape_env = self.tx.output.shape_env
if TracingContext.get().force_unspec_int_unbacked_size_like and isinstance(
value, int
):
wrapped_value = shape_env.create_unbacked_symint()
_constrain_range_for_size(wrapped_value)
self.tx.output.bound_symbols.add(wrapped_value.node.expr)
self.tx.output.tracked_fakes.append(
TrackedFake(wrapped_value, self.source, None)
)
assert is_constant_source(self.get_source())
# TODO: Do I actually need guard for constant source?
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value, source=self.source)

# NB: We do not do float. For motivation, see
# https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
# but the general idea is that we generate kernels that can
# take unspecialized floats and use them in sizevar computation
elif (
type(value) is int
and not is_constant_source(self.get_source())
and not isinstance(self.get_source(), RandomValueSource)
):
if torch._dynamo.config.specialize_int:
# If specialize_int is False, also return
# a constant (but this should have been handled
# in the caller, TBH)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value, source=self.source)

name = self.source.name()
if name not in self.tx.output.frame_state:
# Note - this essentially means that if this name gets reused as a tensor,
# it will start fully dynamic. That should always be a safe option, and not awfully inefficient.
# Alternatively, if we want to improve pef here, we can add a third state of unset, but I am not
# sure that is necessary for now.
frame_state_entry = FrameStateSizeEntry(scalar=value, size=None)
else:
frame_state_entry = self.tx.output.frame_state[name]
if frame_state_entry.scalar != value:
log.debug(
"automatic dynamic int %s val %s != %s",
name,
value,
frame_state_entry.scalar,
)
frame_state_entry.scalar = None
self.tx.output.frame_state[name] = frame_state_entry

# TODO: This should be dynamic, as we in general do not
# know if bare integers are actually going to be sizevars
# and it is inappropriate to eagerly duck size them with
# real sizevars
if (
config.automatic_dynamic_shapes and frame_state_entry.scalar is None
) or not config.assume_static_by_default:
dynamic_dim = DimDynamic.DYNAMIC
else: # assume_static_by_default
# TODO: dynamic_dim = DimDynamic.STATIC should work but
# for some reason it doesn't
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)

wrapped_value = shape_env.create_unspecified_symint_and_symbol(
value,
source=self.source,
dynamic_dim=dynamic_dim,
assert not isinstance(self.get_source(), RandomValueSource)
install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

options = {"source": self.get_source()}

proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(wrapped_value),
source=self.get_source(),
)

unspec_var = wrap_fx_proxy_cls(
SymNodeVariable, # NB: this doesn't actually do anything
tx=self.tx,
proxy=proxy,
example_value=wrapped_value,
**options,
)
self.tx.output.unspec_variable_map[self.name] = unspec_var

if not is_constant_source(self.get_source()):
if self.tx.export and not isinstance(self.get_source(), LocalSource):
raise AssertionError(
f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
)
self.tx.output.bound_symbols.add(wrapped_value.node.expr)

self.tx.output.tracked_fakes.append(
TrackedFake(wrapped_value, self.source, None)
example_value = unspec_var.proxy.node.meta["example_value"]

proxy.node.meta["grapharg"] = GraphArg(
self.get_source(),
wrapped_value,
is_unspecialized=False,
fake_tensor=None,
is_tensor=False,
example_strong_ref=wrapped_value,
)

return unspec_var

def wrap_unspecialized_primitive(self, value):
if self.name in self.tx.output.unspec_variable_map:
return self.tx.output.unspec_variable_map[self.name]

wrapped_value = torch.tensor(value)
if not isinstance(self.get_source(), RandomValueSource):
install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

options = {"source": self.get_source()}
options.update({"raw_value": value})

proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(wrapped_value),
source=self.get_source(),
)

unspec_var = wrap_fx_proxy_cls(
UnspecializedPythonVariable,
tx=self.tx,
proxy=proxy,
example_value=wrapped_value,
**options,
)
self.tx.output.unspec_variable_map[self.name] = unspec_var
if not is_constant_source(self.get_source()):
if self.tx.export and not isinstance(self.get_source(), LocalSource):
raise AssertionError(
f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
)
fake_tensor_value = None
if isinstance(unspec_var, ConstantVariable):
# TODO: when can this happen?
example_value = unspec_var.value
else:
wrapped_value = torch.tensor(value)
if not isinstance(self.get_source(), RandomValueSource):
install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
options = {"source": self.get_source()}
if isinstance(wrapped_value, torch.Tensor):
options.update({"raw_value": value})

proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(wrapped_value),
source=self.get_source(),
)
example_value = unspec_var.proxy.node.meta["example_value"]
assert is_fake(example_value)

unspec_var = wrap_fx_proxy_cls(
UnspecializedPythonVariable,
tx=self.tx,
proxy=proxy,
example_value=wrapped_value,
**options,
fake_tensor_value = example_value
assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
"({self.tx.fake_mode}) from InstructionTranslator"
)
self.tx.output.unspec_variable_map[self.name] = unspec_var
if not is_constant_source(self.get_source()):
if self.tx.export and not isinstance(self.get_source(), LocalSource):
raise AssertionError(
f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
)
fake_tensor_value = None
if isinstance(unspec_var, ConstantVariable):
example_value = unspec_var.value
else:
example_value = unspec_var.proxy.node.meta["example_value"]
if is_fake(example_value):
fake_tensor_value = example_value
assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
"({self.tx.fake_mode}) from InstructionTranslator"
)

proxy.node.meta["grapharg"] = GraphArg(
self.get_source(),
wrapped_value,
isinstance(wrapped_value, torch.Tensor),
fake_tensor_value,
is_tensor=False,
example_strong_ref=wrapped_value,
)
return unspec_var
proxy.node.meta["grapharg"] = GraphArg(
self.get_source(),
wrapped_value,
is_unspecialized=True,
fake_tensor=fake_tensor_value,
is_tensor=False,
example_strong_ref=wrapped_value,
)
return unspec_var


def _dataclasses_fields_lambda(obj):
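As context for the wrap_symint hunk above, its "automatic dynamic int" decision can be summarized with the standalone sketch below. Here frame_state, FrameStateSizeEntry, and the automatic_dynamic_shapes / assume_static_by_default flags are simplified stand-ins, not the real torch._dynamo objects.

# Sketch of the automatic-dynamic decision: specialize an int the first time
# a name is seen, go dynamic once the same name shows up with a different value.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FrameStateSizeEntry:
    scalar: Optional[int]
    size: Optional[list]

frame_state = {}

def decide(name, value, automatic_dynamic_shapes=True, assume_static_by_default=True):
    entry = frame_state.get(name)
    if entry is None:
        # First time we see this name: remember the concrete value.
        entry = FrameStateSizeEntry(scalar=value, size=None)
    elif entry.scalar != value:
        # Same name, different int across frames: stop specializing on it.
        entry.scalar = None
    frame_state[name] = entry

    if (automatic_dynamic_shapes and entry.scalar is None) or not assume_static_by_default:
        return "DYNAMIC: allocate a SymInt for this input"
    return "STATIC: guard on the constant and recompile if it changes"

print(decide("L['n']", 3))  # static on the first frame
print(decide("L['n']", 5))  # dynamic once the value changes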
3 changes: 3 additions & 0 deletions torch/_dynamo/variables/user_defined.py
@@ -705,6 +705,9 @@ def call_function(
             example_value = self.value(*args, **kwargs)
             source = RandomValueSource(random_call_index)
             tx.output.random_calls.append((self.value, args, kwargs))
+            # TODO: arguably, this should route to wrap_symint/wrap_symfloat
+            # (currently hypothetical), but I'm not going to poke my hand in
+            # this nest for now
             return VariableBuilder(tx, source).wrap_unspecialized_primitive(
                 example_value
             )
