Implement native support for float inputs in Dynamo and ShapeEnv #125325
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125325
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 4 Pending as of commit 0c34153 with merge base ee00349.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: db67345ea00a18425f488094fdb404b9f17a51d3 Pull Request resolved: #125325
Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: 01504d0772c645577ff1f081ccca67c609e2b79b Pull Request resolved: #125325
torch/_dynamo/variables/builder.py
Outdated
    source=self.source,
)
else:
    raise AssertionError(f"unrecognized {type(value)}")
raise AssertionError(f"unrecognized {type(value)}") | |
raise TypeError(f"unrecognized {type(value)}") |
Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: ba18d119027f02f5315fae1f1a7e17a71a08b194 Pull Request resolved: #125325
Test failures?
I guess this causes a bunch of wobbles, because there are tests where we're returning a float SymNodeVariable today, and the behavior here has changed.
What happens when you mix SymInts with SymFloats? We know that there are a number of SymPy limitations with floats. What's the plan to handle those?
Overall the PR looks good though.
if symbol_is_type(
    x,
    (
        SymT.UNBACKED_INT,
        SymT.SIZE,
        SymT.PRECOMPUTED_SIZE,
    ),
)
?
oops, I was fiddling around with adding FLOAT to the test here, I eventually decided to get rid of it but the formatting change stuck lol
This: #122823 Basically, if we hand-implement dumb custom functions for all floating point operations, there's no way for sympy to screw us over, because no nontrivial reasoning is supported anyway. Sympy is just a crappy AST container that's convenient for us to use.
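For illustration, a minimal sketch of that style of opaque sympy function (the name `IllustrativeFloatTrueDiv` and the exact folding rule are assumptions for this example; see #122823 for the actual design): sympy can hold the expression, but the only "reasoning" available is folding when every argument is a concrete number.

```
import sympy

class IllustrativeFloatTrueDiv(sympy.Function):
    """Opaque float division: sympy can carry it, but cannot simplify it."""

    @classmethod
    def eval(cls, base, divisor):
        # Fold only when both arguments are plain numbers; otherwise stay symbolic.
        if base.is_number and divisor.is_number:
            return sympy.Float(float(base) / float(divisor))
        return None  # unevaluated: IllustrativeFloatTrueDiv(base, divisor)

zf0 = sympy.Symbol("zf0", real=True)
print(IllustrativeFloatTrueDiv(zf0, sympy.Integer(2)))               # stays opaque
print(IllustrativeFloatTrueDiv(sympy.Float(3.0), sympy.Integer(2)))  # folds to 1.50000000000000
```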
Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: 13b64fc819ca4ddc82a1d71be687de4c535fde70 Pull Request resolved: pytorch#125325
Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: 10ac56aaccdb98939489101d9237ee381ff3786e Pull Request resolved: #125325
All CR addressed
This is great!
# NB: We don't expect you to actually ever generate guards against this
# source, it is ephemeral
Could we throw somewhere inside the class to make sure this is the case?
I think this is inadvisable. The only place we can throw the assertion is in name(), but for debugging purposes you want to always be able to query the name of a source.
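A hypothetical illustration of that trade-off (the class name here is made up, not the actual Dynamo source class): debug paths routinely stringify sources, so an assertion inside name() would fire exactly when you are trying to inspect the problem.

```
from dataclasses import dataclass

@dataclass(frozen=True)
class IllustrativeEphemeralSource:  # stand-in only, not the real Dynamo class
    desc: str = ""

    def name(self) -> str:
        # Deliberately never raises: debug logging like the line below must
        # always work, even though guards are never expected against this source.
        return f"<ephemeral{': ' + self.desc if self.desc else ''}>"

src = IllustrativeEphemeralSource("float input wrapper")
print(f"tracing value from source {src.name()}")  # safe to log anywhere
```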
self._tensor_var = SourcelessBuilder.create(
    tx, torch.scalar_tensor
).call_function(tx, [self], {})
Should we pass device="cpu" here? My concern is: what if this is being run in a context where a global device is set, as we recommend when users want to execute their NumPy code on GPU?
OTOH, I'm not completely sold on the idea, as there may be other devices that don't play ball with CPU scalars?
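For concreteness, a small sketch of the behavior being raised here (using "meta" as a stand-in for a non-CPU device): factory calls such as torch.scalar_tensor pick up the ambient default device set by the torch.device context manager unless device= is passed explicitly.

```
import torch

print(torch.scalar_tensor(1.5).device)                    # cpu (default)

with torch.device("meta"):  # stand-in for e.g. "cuda"
    # Without an explicit device, the scalar lands on the ambient device...
    print(torch.scalar_tensor(1.5).device)                # meta
    # ...whereas device="cpu" pins it regardless of the surrounding context.
    print(torch.scalar_tensor(1.5, device="cpu").device)  # cpu
```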
Can you remind me again how the NumPy-code-on-GPU stuff works?
In particular, see the code:

import numpy as np
import torch

@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    prod = X[:, :, None] * Y[:, None, :]
    print("oops, a graph break!")
    return np.sum(prod, axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
…h#125915) Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#125915 Approved by: https://github.com/lezcano ghstack dependencies: pytorch#125325
…orch#125325)

The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by the `specialize_float` config variable when set to False.

The generated graph looks like this for the test `test_unspec_float_output`:

```
def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
    l_x_ = L_x_
    l_y_ = L_y_

    # File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
    add: "f32[3]" = l_x_ + 1;  l_x_ = None
    item: "Sym(zf0)" = l_y_.item();  l_y_ = None
    mul: "Sym(2*zf0)" = item * 2;  item = None
    scalar_tensor: "f32[]" = torch.scalar_tensor(mul);  mul = None
    return (add, scalar_tensor)
```

The ingredients:

* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up in the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we will do the rest of the tracing with. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.

* **torch/fx/experimental/symbolic_shapes.py** We can now generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols and produce guards for them. Fairly straightforward generalization.

* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we convert it on the fly (via `as_tensor`) to a TensorVariable, which is the true output. We then special case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#125325
Approved by: https://github.com/lezcano, https://github.com/jansel
Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: #126074 Approved by: https://github.com/lezcano ghstack dependencies: #125325, #125915
Stack from ghstack (oldest at bottom):

The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by the `specialize_float` config variable when set to False (a short usage sketch follows at the end of this description). The generated graph for the test `test_unspec_float_output` is shown in the merge commit message above.

The ingredients:

* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up in the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we will do the rest of the tracing with. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.

* **torch/fx/experimental/symbolic_shapes.py** We can now generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols and produce guards for them. Fairly straightforward generalization.

* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we convert it on the fly (via `as_tensor`) to a TensorVariable, which is the true output. We then special case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.

Signed-off-by: Edward Z. Yang [email protected]

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
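A short usage sketch of the feature described above, under the assumption that the toggle is exposed as torch._dynamo.config.specialize_float, using the eager backend to keep the example backend-agnostic:

```
import torch
import torch._dynamo.config

# Assumption: with specialize_float=False (this PR), a float argument is traced
# as a backed SymFloat instead of being specialized into the graph as a constant.
torch._dynamo.config.specialize_float = False

@torch.compile(backend="eager")
def f(x, y):
    # y enters the graph as a scalar tensor, is item()'d to a SymFloat inside,
    # and is converted back to a float on the way out (per the codegen.py notes).
    return x + 1, y * 2

out_tensor, out_float = f(torch.randn(3), 5.0)
print(out_tensor.shape, out_float)  # torch.Size([3]) 10.0
```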