Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce torch.utils._sympy.symbol #125395

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges

from .. import config, metrics
Expand Down Expand Up @@ -1682,7 +1683,8 @@ def rename_indexing(self, index) -> sympy.Expr:
replacements = {
x: self.args.size(x)
for x in sorted_symbols
if x.name.startswith(("s", "u", "ps"))
if symbol_is_type(x, (SymT.UNBACKED_INT, SymT.SIZE))
or x.name.startswith("ps")
or (x.name.startswith("i") and not x.name.startswith("idx"))
}
return sympy_subs(index, replacements)
Expand Down
5 changes: 4 additions & 1 deletion torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torch._inductor.runtime.hints import AutotuneHint
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._triton import has_triton_package

Expand Down Expand Up @@ -1671,7 +1672,9 @@ def indexing(
# indirect indexing
cse_var = self.cse.varname_map[var.name]
mask_vars.update(cse_var.mask_vars)
elif var.name.startswith(("s", "ps", "i", "u")):
elif var.name.startswith(("ps", "i")) or symbol_is_type(
var, (SymT.UNBACKED_INT, SymT.SIZE)
):
pass
else:
# var is one of xN, yN or rN
Expand Down
20 changes: 10 additions & 10 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from torch._utils_internal import signpost_event
from torch._subclasses.meta_utils import is_sparse_any
import torch.utils._pytree as pytree
from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type

from torch._logging import LazyString

Expand Down Expand Up @@ -439,7 +440,7 @@ def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool:
# Like free_symbols, but filtered to only report unbacked symbols
def free_unbacked_symbols(x):
# NB: keep synced with is_unbacked_symint
return {s for s in free_symbols(x) if s.name.startswith(("u", "f"))}
return {s for s in free_symbols(x) if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))}

# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
# setup!
Expand Down Expand Up @@ -1644,7 +1645,7 @@ def _reduce_congruences(self):
# We are given a congruence of the form base % divisor == 0 with a free variable s. So:
# - we transform this into an equation of the form base = divisor * tmp;
# - we solve this equation for s to get a linear solution with free variable tmp.
tmp = sympy.Symbol("tmp", integer=True)
tmp = sympy.Symbol("reduce_congruences_tmp", integer=True)
symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s])
# See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear
# for how to interpret the results.
Expand Down Expand Up @@ -2998,7 +2999,7 @@ def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges):
def create_unbacked_symfloat(self):
"""Create a symbolic float without a hint value
"""
symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_FLOAT, next(self.unbacked_symfloat_counter))
self.counter["create_unbacked_symbol"] += 1
if not self._ignore_fresh_unbacked_symbols_tls():
self.pending_fresh_unbacked_symbols.append(symbol)
Expand All @@ -3016,7 +3017,7 @@ def create_unbacked_symfloat(self):
def create_unbacked_symint(self):
"""Create a symbolic integer without a hint value
"""
symbol: sympy.Symbol = sympy.Symbol(f"u{next(self.unbacked_symint_counter)}", integer=True)
symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True)
if not self._ignore_fresh_unbacked_symbols_tls():
self.pending_fresh_unbacked_symbols.append(symbol)
self.counter["create_unbacked_symbol"] += 1
Expand All @@ -3033,14 +3034,13 @@ def create_unbacked_symint(self):
def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
"""Check if a sympy symbol matches the naming convention for unbacked symbols
"""
# NB: keep synced with free_unbacked_symbols
return str(symbol).startswith("u")
return symbol_is_type(symbol, SymT.UNBACKED_INT)

@record_shapeenv_event()
def create_unbacked_symbool(self):
"""Create a symbolic boolean without a hint value
"""
symbol: sympy.Symbol = sympy.Symbol(f"u{next(self.unbacked_symint_counter)}", integer=True)
symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True)
if not self._ignore_fresh_unbacked_symbols_tls():
self.pending_fresh_unbacked_symbols.append(symbol)
self.counter["create_unbacked_symbol"] += 1
Expand Down Expand Up @@ -3152,7 +3152,7 @@ def create_symbol(
# If we're not duck shaping, we always create a new symbol
# Even if we're duck shaping, if we haven't seen this particular
# value before, we also create a new symbol
sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=positive, integer=True)
sympy_expr = make_symbol(SymT.SIZE, len(self.var_to_val), positive=positive, integer=True)
# We always associate vars to vals
if isinstance(val, int):
self.var_to_val[sympy_expr] = sympy.Integer(val)
Expand Down Expand Up @@ -4063,7 +4063,7 @@ def _maybe_evaluate_static(
# we have to increase it by offset (and conversely, the new
# variables have to have their value range bounds adjusted as
# well)
s = sympy.Symbol(f"shape_{idx}", positive=True, integer=True)
s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)

# Note:
# Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
Expand Down Expand Up @@ -4847,7 +4847,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
stack = CapturedTraceback.extract(skip=1)
ra = RuntimeAssert(expr, msg, stack)
# TODO: Do this in a way that is less janky than int(s.name[1:])
cands = sorted([s for s in expr.free_symbols if s.name.startswith("u")], key=lambda s: int(s.name[1:]))
cands = sorted([s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)], key=lambda s: int(s.name[1:]))
ezyang marked this conversation as resolved.
Show resolved Hide resolved
# Is None when prefer_deferred_runtime_asserts_over_guards=True
# and the guard in question has no unbacked SymInts in front
ix = cands[-1] if cands else None
Expand Down
44 changes: 44 additions & 0 deletions torch/utils/_sympy/symbol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
This file contains canonical definitions for our symbol naming conventions,
across torch.fx.experimental.symbolic_shapes and torch._inductor. The
intention is:

1. To make it easily greppable where all the sites we use a prefix are
2. Make it possible to easily tell if we can introduce a new prefix without
introducing a conflict

You can occasionally test if prefixes have been hardcoded by renaming prefixes
in this file and seeing what breaks.
"""

from enum import auto, Enum
from typing import Sequence, Union

import sympy


class SymT(Enum):
SIZE = auto()
UNBACKED_INT = auto()
UNBACKED_FLOAT = auto()


# Invariant: there must not be a prefix which is a prefix of another string,
# as this introduces ambiguity
prefix_str = {
SymT.SIZE: "s", # integer
SymT.UNBACKED_INT: "u", # integer
SymT.UNBACKED_FLOAT: "f",
}


def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
# TODO: maybe put the assumptions here directly
return sympy.Symbol(f"{prefix}{idx}", **kwargs)


def symbol_is_type(sym: sympy.Symbol, prefix: Union[SymT, Sequence[SymT]]) -> bool:
if isinstance(prefix, SymT):
return sym.name.startswith(prefix_str[prefix])
else:
return sym.name.startswith(tuple(prefix_str[p] for p in prefix))