
Implement native support for float inputs in Dynamo and ShapeEnv #125325

Closed
wants to merge 8 commits

Conversation

ezyang (Contributor) commented May 1, 2024

Stack from ghstack (oldest at bottom):

The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by the `specialize_float` config variable and takes effect when it is set to False.

The generated graph looks like this for the test `test_unspec_float_output`:

```
 def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
     l_x_ = L_x_
     l_y_ = L_y_

     # File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
     add: "f32[3]" = l_x_ + 1;  l_x_ = None
     item: "Sym(zf0)" = l_y_.item();  l_y_ = None
     mul: "Sym(2*zf0)" = item * 2;  item = None
     scalar_tensor: "f32[]" = torch.scalar_tensor(mul);  mul = None
     return (add, scalar_tensor)
```
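As a point of reference, here is a minimal sketch of how this mode could be exercised from user code. It is illustrative only: it assumes the `specialize_float` knob described above is toggled directly on `torch._dynamo.config`, and it uses the eager backend just to exercise tracing.

```python
import torch
import torch._dynamo.config as dynamo_config

# Assumption: turning off float specialization makes Dynamo trace float
# arguments symbolically instead of baking the concrete value into the graph.
dynamo_config.specialize_float = False

@torch.compile(backend="eager")
def f(x, y):
    return x + 1, y * 2

t, s = f(torch.randn(3), 5.0)   # y enters the graph boxed as a 0-d tensor
t, s = f(torch.randn(3), 7.0)   # a different float should not force a full respecialization
```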

The ingredients:

* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up in the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we do the rest of the tracing with. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We can now generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs from the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable for a float on the return stack, we convert it on the fly (via `as_tensor`) to a TensorVariable, which is the true output. We then special-case the output bytecode to call item() on it again. The tensor conversion is memoized on the SymNodeVariable since we typically run the code generation process twice. (The sketch after this list walks through the resulting calling convention end to end.)
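Putting the three pieces together, the calling convention can be summarized with a hand-written eager sketch (illustrative only; `traced_equivalent` and `compiled_wrapper` are hypothetical stand-ins for the FX graph and the generated bytecode, not Dynamo internals):

```python
import torch

def user_fn(x: torch.Tensor, y: float):
    return x + 1, y * 2

def traced_equivalent(L_x_: torch.Tensor, L_y_: torch.Tensor):
    # Inside the graph: immediately unbox the synthetic 0-d tensor into a
    # (symbolic) float and trace plain float arithmetic on it.
    item = L_y_.item()
    mul = item * 2
    # No float outputs are allowed, so box the result back into a 0-d tensor.
    return L_x_ + 1, torch.scalar_tensor(mul)

def compiled_wrapper(x: torch.Tensor, y: float):
    # Input side: the float argument is passed to the graph as a 0-d tensor.
    add, boxed = traced_equivalent(x, torch.scalar_tensor(y))
    # Output side: the bytecode calls item() so the caller still sees a float.
    return add, boxed.item()

x = torch.randn(3)
assert compiled_wrapper(x, 3.5)[1] == user_fn(x, 3.5)[1] == 7.0
```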

Signed-off-by: Edward Z. Yang [email protected]

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

pytorch-bot commented May 1, 2024

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125325
Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 4 Pending
As of commit 0c34153 with merge base ee00349:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

```python
        source=self.source,
    )
else:
    raise AssertionError(f"unrecognized {type(value)}")
```

A Collaborator suggested the following change:

```diff
-raise AssertionError(f"unrecognized {type(value)}")
+raise TypeError(f"unrecognized {type(value)}")
```

ezyang changed the title from "[POC] Implement native support for float inputs in Dynamo and ShapeEnv" to "Implement native support for float inputs in Dynamo and ShapeEnv" on May 9, 2024
jansel (Contributor) left a comment:
Test failures?

ezyang (Contributor, Author) commented May 9, 2024

I guess this causes a bunch of wobbles, because there are tests where we're returning a float SymNodeVariable today, and the behavior here has changed.

lezcano (Collaborator) left a comment:

What happens when you mix SymInts with SymFloats? We know that there are a number of SymPy limitations with floats. What's the plan to handle those?

Overall the PR looks good though.

Comment on lines +1720 to +1727:

```python
if symbol_is_type(
    x,
    (
        SymT.UNBACKED_INT,
        SymT.SIZE,
        SymT.PRECOMPUTED_SIZE,
    ),
)
```
Collaborator: ?

ezyang (Author) replied: Oops, I was fiddling around with adding FLOAT to the test here; I eventually decided to get rid of it, but the formatting change stuck, lol.

ezyang (Contributor, Author) commented May 10, 2024

> What happens when you mix SymInts with SymFloats? We know that there are a number of SymPy limitations with floats. What's the plan to handle those?

This: #122823

Basically, if we hand-implement dumb custom functions for all floating-point operations, there's no way for sympy to screw us over, because no nontrivial reasoning is supported anyway. Sympy is just a crappy AST container that happens to be convenient for us to use.
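For illustration, a minimal sketch of what such a deliberately dumb sympy function can look like (hypothetical class, not the actual definitions used by PyTorch or #122823): eval() only constant-folds concrete numbers and otherwise leaves the expression as an uninterpreted node, so sympy has no opportunity to apply its float-unfriendly rewrites.

```python
import sympy

class FloatTrueDiv(sympy.Function):
    """Opaque float-division node (illustrative sketch only)."""

    is_real = True

    @classmethod
    def eval(cls, a, b):
        # Only fold when both arguments are concrete numbers; otherwise return
        # None so the expression stays unevaluated and sympy does no reasoning.
        if a.is_Number and b.is_Number and b != 0:
            return sympy.Float(float(a) / float(b))
        return None

zf0 = sympy.Symbol("zf0", real=True)
print(FloatTrueDiv(zf0, 2.0))                            # FloatTrueDiv(zf0, 2.0) -- no rewriting
print(FloatTrueDiv(sympy.Float(3.0), sympy.Float(2.0)))  # 1.50000000000000
```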

ezyang (Contributor, Author) commented May 13, 2024

All CR addressed

lezcano (Collaborator) left a comment:
This is great!

Comment on lines +563 to +564:

```python
# NB: We don't expect you to actually ever generate guards against this
# source, it is ephemeral
```
Collaborator: Could we throw somewhere inside the class to make sure this is the case?

ezyang (Author) replied:
I think this is inadvisable. The only place we can throw the assertion is in name(), but for debugging purposes you want to always be able to query the name of a source.


```python
self._tensor_var = SourcelessBuilder.create(
    tx, torch.scalar_tensor
).call_function(tx, [self], {})
```
Collaborator: Should we pass device="cpu" here? My concern is: what if this is being run in a context where a global device is set, as we recommend when users want to execute their NumPy code on GPU?

OTOH, I'm not completely sold on the idea, as there may be other devices that don't play ball with CPU scalars.

ezyang (Author) replied: Can you remind me again how the NumPy-code-on-GPU stuff works?

Collaborator: See https://pytorch.org/docs/stable/torch.compiler_faq.html#can-i-execute-numpy-code-on-cuda-and-compute-gradients-via-torch-compile

In particular, see the code:

```python
@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    prod = X[:, :, None] * Y[:, None, :]
    print("oops, a graph break!")
    return np.sum(prod, axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")

with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
```
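For what it's worth, factory functions like torch.scalar_tensor do (to the best of my recollection) follow an ambient default device set via the torch.device context manager, which is exactly the scenario being raised. A small sketch, using the "meta" device so it runs without a GPU:

```python
import torch

print(torch.scalar_tensor(1.5).device)                    # cpu (default)

# Under a global device context, factory calls that omit device= follow it.
with torch.device("meta"):
    print(torch.scalar_tensor(1.5).device)                # meta

# An explicit device="cpu" would pin the scalar to CPU regardless of context.
with torch.device("meta"):
    print(torch.scalar_tensor(1.5, device="cpu").device)  # cpu
```

Whether pinning to CPU is desirable here depends on how backends treat CPU scalars, which is the open question above.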

ezyang (Contributor, Author) commented May 13, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot (Collaborator) commented:

Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command.

ezyang (Contributor, Author) commented May 13, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot (Collaborator) commented:

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command. For more information, see the pytorch-bot wiki.

smalltalkman and tinglvv pushed commits to their forks that referenced this pull request May 14, 2024:

Pull Request resolved: pytorch#125325
Approved by: https://github.com/lezcano, https://github.com/jansel
pytorchmergebot pushed a commit that referenced this pull request May 14, 2024
Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: #126074
Approved by: https://github.com/lezcano
ghstack dependencies: #125325, #125915