
[guards][cpp-guards] Optimize NN module getattr guards #124522

Closed

wants to merge 27 commits

Conversation

pytorch-bot bot commented Apr 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124522

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 4af936f with merge base 3b5f6b1:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

anijain2305 added a commit that referenced this pull request Apr 19, 2024
ghstack-source-id: 4589e8da5eca9343d7da392652e04cf92edd084e
Pull Request resolved: #124522
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 22, 2024
ghstack-source-id: 867b3748f8bc63c966f084ea68a03b2e505f8619
Pull Request resolved: #124522
Improves the guard overhead of MobileBert model with nn module guards from 92000 units to 34000 units.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
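
For context, the overhead numbers above are for runs with nn-module guarding enabled. A minimal sketch of that setup (the toy module, shapes, and backend below are invented for illustration; only the guard_nn_modules flag and torch.compile come from this thread):

import torch
import torch._dynamo

# guard_nn_modules turns on the nn-module guards whose overhead this PR
# reduces (the same flag is used in the repro later in this conversation).
torch._dynamo.config.guard_nn_modules = True

class TinyMod(torch.nn.Module):  # toy stand-in; MobileBert itself is much larger
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.scale = 2.0

    def forward(self, x):
        # Attribute reads such as self.linear and self.scale are the
        # nn-module getattr accesses that get guarded.
        return self.linear(x) * self.scale

opt_mod = torch.compile(TinyMod(), backend="eager")
opt_mod(torch.randn(2, 8))
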
anijain2305 added a commit that referenced this pull request Apr 22, 2024
ghstack-source-id: 814798e76bdefd38ca37aa15ab565f464fee4cee
Pull Request resolved: #124522
@anijain2305 anijain2305 changed the title from "[wip][guards][cpp-guards] Optimize NN module getattr guards" to "[guards][cpp-guards] Optimize NN module getattr guards" Apr 22, 2024
@anijain2305 anijain2305 added the keep-going (Don't stop on first failure, keep running tests until the end) label Apr 22, 2024
Improves the guard overhead of MobileBert model with nn module guards from 92000 units to 38000 units.
@anijain2305 anijain2305 marked this pull request as draft April 22, 2024 22:12
Improves the guard overhead of MobileBert model with nn module guards from 92000 units to 20000 units.
anijain2305 added a commit that referenced this pull request Apr 23, 2024
ghstack-source-id: 7b52ca98577226915b946408da6659061c99bc11
Pull Request resolved: #124522
anijain2305 added a commit that referenced this pull request Apr 23, 2024
ghstack-source-id: 864bac523df4e3c9148e907420beac49b1ad2c2c
Pull Request resolved: #124522
anijain2305 added a commit that referenced this pull request Apr 24, 2024
ghstack-source-id: 52d5675380bbd6c72a217b3c1c1c3a72599dafee
Pull Request resolved: #124522
@anijain2305 anijain2305 requested a review from jansel April 26, 2024 18:35

@jansel jansel left a comment


add some more tests for this

    mgr, key, source_name, base_example_value, example_value, guard_manager_enum
):
    if isinstance(mgr, DictGuardManager):
        # Case where the user code relies on key order, e.g.,
@jansel (Contributor):

Are there tests covering this case? (Where we change the order)

@anijain2305 (Contributor Author):

There were none for NNModuleVariableTracker. There are for UnSpecializedNNModuleVariable.

Good catch. Added and made the fixes as well.
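
For reference, a minimal sketch of the kind of key-order-dependent user code being discussed (a hypothetical illustration, not the test added in this PR; the module and the reordering trick below are invented for the example):

import torch
import torch._dynamo

torch._dynamo.config.guard_nn_modules = True

class KeyOrderMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        # The output depends on the iteration (key) order of _modules,
        # so the compiled code must guard on that order.
        for submod in self._modules.values():
            x = submod(x)
        return x

mod = KeyOrderMod()
opt_mod = torch.compile(mod, backend="eager")
opt_mod(torch.randn(4))

# Move "relu" to the end of _modules, changing the key order; the
# key-order guard should fail here and force a recompile instead of
# reusing code that still applies relu before sigmoid.
mod._modules["relu"] = mod._modules.pop("relu")
opt_mod(torch.randn(4))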

    for x in inspect.getmro(base_example_value.__class__):
        all_class_attribute_names.update(x.__dict__.keys())

    if attr_name in all_class_attribute_names:
@jansel (Contributor):

class attributes are shadowed by module dict attributes... I think the order of these is wrong. Add some tests for shadowing.

@anijain2305 (Contributor Author):

You are right. Thanks for pointing this out; I will fix it. But there is a corner case with submodules, where class attributes take precedence. Here is the example:

import torch
import torch._dynamo.testing

torch._dynamo.config.guard_nn_modules = True


class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 3

    def forward(self, x, c=4):
        return x * c

    def linear(self, x):
        return x

    # b is a class attribute; this is the shadowing case that jansel pointed out:
    # the instance attribute mod.b shadows the class attribute Mod.b
    def b(self, x):
        assert False
        return x

class MyMod(Mod):
    def __init__(self):
        super().__init__()
        # IF THIS LINE ACTUALLY CHANGED self.linear, THE forward CALL WOULD FAIL WITH A SHAPE MISMATCH ERROR
        self.linear = torch.nn.Linear(11, 11)
        self.a = 2
        self.b = 2


    def forward(self, x, c=None):
        return self.linear(x) * self.a * self.b


mod = MyMod()
mod(torch.ones(3, 3))

# Instance attributes set in the subclass __init__ overwrite those set by the parent __init__.
assert mod.a == 2

# ANOMALY - mod.linear is not the nn.Linear submodule; it resolves to the parent class method
assert mod.linear.__code__ is Mod.linear.__code__

# jansel's valid comment - module dict attrs overshadow the class attrs
# anijain2305 - This does not cause a bug because today we fall back to getattr,
# but we are losing a perf opportunity here.
assert mod.b == 2

cnts = torch._dynamo.testing.CompileCounter()
opt_mod = torch.compile(mod, backend=cnts)
opt_mod(torch.ones(3, 3))
opt_mod(torch.ones(3, 3))

assert cnts.frame_count == 1

The discrepancy is for mod.linear. The parent class has a function named linear, so when the derived-class instance accesses self.linear, it gets the parent class function: the submodule is stored in _modules rather than the instance __dict__, so regular attribute lookup finds the class attribute first and nn.Module.__getattr__ is never consulted.

This is a special case. Apart from that, everything else conforms to the general rule: "class attributes are shadowed by the module __dict__ attributes".

I also think this is a bug in nn.Module.__setattr__. For example, I am unable to assign self.param if the parent class already has a function named param. If you agree this is just a bug, I can send a PR for nn modules.
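
A minimal sketch of that __setattr__ behavior (hypothetical repro; it assumes the current nn.Module behavior where register_parameter rejects a name that already exists as an attribute on the class):

import torch

class Base(torch.nn.Module):
    # A method (class attribute) named "param".
    def param(self, x):
        return x

class Child(Base):
    def __init__(self):
        super().__init__()
        # nn.Module.__setattr__ routes Parameter assignment through
        # register_parameter, which raises KeyError("attribute 'param'
        # already exists") because hasattr(self, "param") is True,
        # so the parameter cannot shadow the inherited method.
        self.param = torch.nn.Parameter(torch.randn(3))

try:
    Child()
except KeyError as e:
    print("rejected:", e)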

@pytorch-bot pytorch-bot bot added the oncall: distributed (Add this issue/PR to distributed oncall triage queue) label May 2, 2024
@anijain2305 (Contributor Author):

@jansel just a gentle ping in case this slipped through your notifications.

pytorchmergebot pushed a commit that referenced this pull request May 4, 2024
fathnd pushed a commit to fathnd/homomorphic that referenced this pull request May 5, 2024
ghstack-source-id: 5af5a978ab8b672fcfe1ee0622f48ff9d376635f
Pull Request resolved: pytorch/pytorch#124522
@github-actions github-actions bot deleted the gh/anijain2305/292/head branch June 5, 2024 01:54
Labels
ciflow/inductor, ciflow/slow, keep-going, Merged, module: dynamo, oncall: distributed
3 participants