forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This provides utilities for creating and querying properties on sympy.Symbol. I want to use this refactor to get a better handle on how the 's' prefix is being used in Inductor. To start, I only do symbolic_shapes code because that's what I'm familiar with. Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: 84d52f217a07e983db1adcf4acba4dd0c28d7c9d Pull Request resolved: pytorch#125395
- Loading branch information
Showing
6 changed files
with
68 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
""" | ||
This file contains canonical definitions for our symbol naming conventions, | ||
across torch.fx.experimental.symbolic_shapes and torch._inductor. The | ||
intention is: | ||
1. To make it easily greppable where all the sites we use a prefix are | ||
2. Make it possible to easily tell if we can introduce a new prefix without | ||
introducing a conflict | ||
You can occasionally test if prefixes have been hardcoded by renaming prefixes | ||
in this file and seeing what breaks. | ||
""" | ||
|
||
from enum import auto, Enum | ||
from typing import Sequence, Union | ||
|
||
import sympy | ||
|
||
|
||
class SymT(Enum): | ||
SIZE = auto() | ||
UNBACKED_INT = auto() | ||
UNBACKED_FLOAT = auto() | ||
|
||
|
||
# Invariant: there must not be a prefix which is a prefix of another string, | ||
# as this introduces ambiguity | ||
prefix_str = { | ||
SymT.SIZE: "s", # integer | ||
SymT.UNBACKED_INT: "u", # integer | ||
SymT.UNBACKED_FLOAT: "f", | ||
} | ||
|
||
|
||
def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol: | ||
# TODO: maybe put the assumptions here directly | ||
return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs) | ||
|
||
|
||
def symbol_is_type(sym: sympy.Symbol, prefix: Union[SymT, Sequence[SymT]]) -> bool: | ||
if isinstance(prefix, SymT): | ||
return sym.name.startswith(prefix_str[prefix]) | ||
else: | ||
return sym.name.startswith(tuple(prefix_str[p] for p in prefix)) |