Skip to content

Commit

Permalink
Introduce torch.utils._sympy.symbol
Browse files Browse the repository at this point in the history
This provides utilities for creating and querying properties on
sympy.Symbol.  I want to use this refactor to get a better handle on how
the 's' prefix is being used in Inductor.  To start, I only do
symbolic_shapes code because that's what I'm familiar with.

Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: 84d52f217a07e983db1adcf4acba4dd0c28d7c9d
Pull Request resolved: pytorch#125395
  • Loading branch information
ezyang committed May 3, 2024
1 parent 79af814 commit 43662e6
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 17 deletions.
4 changes: 3 additions & 1 deletion test/inductor/test_compiled_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,11 +1516,13 @@ def wrap_test_class(orig_cls):
elif name.startswith("test_"):
dct[name] = make_wrapped(fn)

return type(
cls = type(
orig_cls.__name__ + "WithCompiledAutograd",
orig_cls.__bases__,
dct,
)
cls.__file__ = __file__
return cls


# These groups of tests aren't supported yet
Expand Down
4 changes: 3 additions & 1 deletion torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges

from .. import config, metrics
Expand Down Expand Up @@ -1682,7 +1683,8 @@ def rename_indexing(self, index) -> sympy.Expr:
replacements = {
x: self.args.size(x)
for x in sorted_symbols
if x.name.startswith(("s", "u", "ps"))
if symbol_is_type(x, (SymT.UNBACKED_INT, SymT.SIZE))
or x.name.startswith("ps")
}
return sympy_subs(index, replacements)

Expand Down
5 changes: 4 additions & 1 deletion torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._triton import has_triton_package

Expand Down Expand Up @@ -1671,7 +1672,9 @@ def indexing(
# indirect indexing
cse_var = self.cse.varname_map[var.name]
mask_vars.update(cse_var.mask_vars)
elif var.name.startswith(("s", "ps", "i", "u")):
elif var.name.startswith(("ps", "i")) or symbol_is_type(
var, (SymT.UNBACKED_INT, SymT.SIZE)
):
pass
else:
# var is one of xN, yN or rN
Expand Down
8 changes: 4 additions & 4 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,8 @@ def trunc(x):

@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

(x,) = promote_constants([x])
if isinstance(x, ir.BaseConstant):
return ExpandView.create(x, tuple(sizes))
Expand All @@ -839,15 +841,13 @@ def expand(x, sizes):
if tuple(x.get_size()) == tuple(sizes):
return x

if not any(V.graph.sizevars.shape_env.is_unbacked_symint(s) for s in x.get_size()):
if not free_unbacked_symbols(x.get_size()):
x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size()))
# TODO: It would be better to realize the input if any of its sizes
# are unbacked, because typically the size will be non-zero. However,
# this cannot be done directly as below as we'll choke on the size_hint
# here
if x_size_product > 0 and not any(
V.graph.sizevars.shape_env.is_unbacked_symint(s) for s in sizes
):
if x_size_product > 0 and not free_unbacked_symbols(sizes):
# maybe realize input before broadcasting it
x.mark_reuse(
V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product
Expand Down
20 changes: 10 additions & 10 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from torch._utils_internal import signpost_event
from torch._subclasses.meta_utils import is_sparse_any
import torch.utils._pytree as pytree
from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type

from torch._logging import LazyString

Expand Down Expand Up @@ -439,7 +440,7 @@ def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool:
# Like free_symbols, but filtered to only report unbacked symbols
def free_unbacked_symbols(x):
# NB: keep synced with is_unbacked_symint
return {s for s in free_symbols(x) if s.name.startswith(("u", "f"))}
return {s for s in free_symbols(x) if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))}

# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
# setup!
Expand Down Expand Up @@ -1661,7 +1662,7 @@ def _reduce_congruences(self):
# We are given a congruence of the form base % divisor == 0 with a free variable s. So:
# - we transform this into an equation of the form base = divisor * tmp;
# - we solve this equation for s to get a linear solution with free variable tmp.
tmp = sympy.Symbol("tmp", integer=True)
tmp = sympy.Symbol("reduce_congruences_tmp", integer=True)
symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s])
# See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear
# for how to interpret the results.
Expand Down Expand Up @@ -3025,7 +3026,7 @@ def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges):
def create_unbacked_symfloat(self):
"""Create a symbolic float without a hint value
"""
symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_FLOAT, next(self.unbacked_symfloat_counter))
self.counter["create_unbacked_symbol"] += 1
if not self._ignore_fresh_unbacked_symbols_tls():
self.pending_fresh_unbacked_symbols.append(symbol)
Expand All @@ -3043,7 +3044,7 @@ def create_unbacked_symfloat(self):
def create_unbacked_symint(self):
"""Create a symbolic integer without a hint value
"""
symbol: sympy.Symbol = sympy.Symbol(f"u{next(self.unbacked_symint_counter)}", integer=True)
symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True)
if not self._ignore_fresh_unbacked_symbols_tls():
self.pending_fresh_unbacked_symbols.append(symbol)
self.counter["create_unbacked_symbol"] += 1
Expand All @@ -3060,14 +3061,13 @@ def create_unbacked_symint(self):
def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
"""Check if a sympy symbol matches the naming convention for unbacked symbols
"""
# NB: keep synced with free_unbacked_symbols
return str(symbol).startswith("u")
return symbol_is_type(symbol, SymT.UNBACKED_INT)

@record_shapeenv_event()
def create_unbacked_symbool(self):
"""Create a symbolic boolean without a hint value
"""
symbol: sympy.Symbol = sympy.Symbol(f"u{next(self.unbacked_symint_counter)}", integer=True)
symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True)
if not self._ignore_fresh_unbacked_symbols_tls():
self.pending_fresh_unbacked_symbols.append(symbol)
self.counter["create_unbacked_symbol"] += 1
Expand Down Expand Up @@ -3179,7 +3179,7 @@ def create_symbol(
# If we're not duck shaping, we always create a new symbol
# Even if we're duck shaping, if we haven't seen this particular
# value before, we also create a new symbol
sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=positive, integer=True)
sympy_expr = make_symbol(SymT.SIZE, len(self.var_to_val), positive=positive, integer=True)
# We always associate vars to vals
if isinstance(val, int):
self.var_to_val[sympy_expr] = sympy.Integer(val)
Expand Down Expand Up @@ -4094,7 +4094,7 @@ def _maybe_evaluate_static(
# we have to increase it by offset (and conversely, the new
# variables have to have their value range bounds adjusted as
# well)
s = sympy.Symbol(f"shape_{idx}", positive=True, integer=True)
s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)

# Note:
# Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
Expand Down Expand Up @@ -4896,7 +4896,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
stack = CapturedTraceback.extract(skip=1)
ra = RuntimeAssert(expr, msg, stack)
# TODO: Do this in a way that is less janky than int(s.name[1:])
cands = sorted([s for s in expr.free_symbols if s.name.startswith("u")], key=lambda s: int(s.name[1:]))
cands = sorted((s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)), key=lambda s: int(s.name[1:]))
# Is None when prefer_deferred_runtime_asserts_over_guards=True
# and the guard in question has no unbacked SymInts in front
ix = cands[-1] if cands else None
Expand Down
44 changes: 44 additions & 0 deletions torch/utils/_sympy/symbol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
This file contains canonical definitions for our symbol naming conventions,
across torch.fx.experimental.symbolic_shapes and torch._inductor. The
intention is:
1. To make it easily greppable where all the sites we use a prefix are
2. Make it possible to easily tell if we can introduce a new prefix without
introducing a conflict
You can occasionally test if prefixes have been hardcoded by renaming prefixes
in this file and seeing what breaks.
"""

from enum import auto, Enum
from typing import Sequence, Union

import sympy


class SymT(Enum):
SIZE = auto()
UNBACKED_INT = auto()
UNBACKED_FLOAT = auto()


# Invariant: there must not be a prefix which is a prefix of another string,
# as this introduces ambiguity
prefix_str = {
SymT.SIZE: "s", # integer
SymT.UNBACKED_INT: "u", # integer
SymT.UNBACKED_FLOAT: "f",
}


def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
# TODO: maybe put the assumptions here directly
return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs)


def symbol_is_type(sym: sympy.Symbol, prefix: Union[SymT, Sequence[SymT]]) -> bool:
if isinstance(prefix, SymT):
return sym.name.startswith(prefix_str[prefix])
else:
return sym.name.startswith(tuple(prefix_str[p] for p in prefix))

0 comments on commit 43662e6

Please sign in to comment.