Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delegates prevents model conversion to TorchScript #494

Open
oguiza opened this issue Oct 21, 2022 · 0 comments
Open

delegates prevents model conversion to TorchScript #494

oguiza opened this issue Oct 21, 2022 · 0 comments

Comments

@oguiza
Copy link

oguiza commented Oct 21, 2022

Hi,
I've found an issue when trying to convert a PyTorch module to TorchScript. I've isolated the issue and created the snippets below to reproduce it.
As a temporary workaround, I had to stop using `delegates` in the tsai library to avoid this issue, but I'd appreciate any help with this. Have any of you experienced this before?

THIS WORKS (without `@delegates`)

import torch
import torch.nn as nn
from fastcore.meta import delegates

class DelegatesTest(nn.Module): 
    """Control case: same class as in the failing snippet, but without @delegates.

    It is never instantiated in this snippet; defining it undecorated has no
    effect on tracing ``Net`` below.
    """
    def __init__(self, **kwargs):  
        # NOTE(review): forwards every keyword argument to nn.Module.__init__;
        # whether that accepts keyword arguments depends on the torch version --
        # confirm before ever instantiating this class with kwargs.
        super().__init__(**kwargs)

class Net(nn.Module):
    """Minimal module to trace: a single Conv2d (1 input channel, 1 output channel, 3x3 kernel)."""
    def __init__(self):
        super(Net, self).__init__()
        # The only submodule / parameters of this network.
        self.conv = nn.Conv2d(1, 1, 3)
    def forward(self, x):
        # Apply the convolution; no other computation is performed.
        return self.conv(x)

# Build the model and one random input tensor of shape (1, 1, 3, 3).
n = Net()
inp = torch.rand(1, 1, 3, 3)
# Eager forward pass.
output = n(inp)
print(output)
# Trace to TorchScript; in this control snippet (no @delegates anywhere) it succeeds.
module = torch.jit.trace(n, inp)
print(module)

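THIS DOESN'T WORK (the code is identical, except that `DelegatesTest` is decorated with `@delegates(nn.Linear.__init__)`, even though `Net` never uses `DelegatesTest`)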

import torch
import torch.nn as nn
from fastcore.meta import delegates

# NOTE(review): this class is never instantiated and Net never references it,
# yet with this decorator present torch.jit.trace(Net(), ...) fails below.
# The traceback ends in torch._jit_internal.is_final(ann) with ann = None while
# torch.jit iterates a module's class annotations. So the decorator apparently
# leaks a None-valued annotation into state shared with other nn.Module
# subclasses. Presumably, when `delegates` decorates a class, it updates the
# `__annotations__` dict that DelegatesTest inherits from nn.Module (e.g. adding
# `'return': None` from nn.Linear.__init__). TODO confirm against
# fastcore.meta.delegates. Applying @delegates to `__init__` itself instead of
# the class may avoid this; unverified.
@delegates(nn.Linear.__init__)
class DelegatesTest(nn.Module): 
    def __init__(self, **kwargs):  
        # NOTE(review): forwards every keyword argument to nn.Module.__init__;
        # whether that accepts keyword arguments depends on the torch version.
        super().__init__(**kwargs)

class Net(nn.Module):
    """Minimal module to trace: a single Conv2d (1 input channel, 1 output channel, 3x3 kernel)."""
    def __init__(self):
        super(Net, self).__init__()
        # The only submodule / parameters of this network; note that it does
        # not use DelegatesTest at all.
        self.conv = nn.Conv2d(1, 1, 3)
    def forward(self, x):
        # Apply the convolution; no other computation is performed.
        return self.conv(x)

# Same steps as the working snippet.
n = Net()
inp = torch.rand(1, 1, 3, 3)
# Eager forward pass; the reported traceback starts at the trace call, not here.
output = n(inp)
print(output)
# Raises AttributeError: 'NoneType' object has no attribute '__module__'.
# The error comes from torch._jit_internal.is_final, which is called while
# torch.jit builds a concrete type for a traced (sub)module and iterates its
# class annotations; one of those annotation values is None.
module = torch.jit.trace(n, inp)
print(module)

The call to `torch.jit.trace` raises the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/var/folders/42/4hhwknbd5kzcbq48tmy_gbp00000gn/T/ipykernel_91620/2045013617.py in <module>
     19 output = n(inp)
     20 print(output)
---> 21 module = torch.jit.trace(n, inp)
     22 print(module)

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    757             strict,
    758             _force_outplace,
--> 759             _module_class,
    760         )
    761 

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    949         register_submods(mod, "__module")
    950 
--> 951         module = make_module(mod, _module_class, _compilation_unit)
    952 
    953         for method_name, example_inputs in inputs.items():

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in make_module(mod, _module_class, _compilation_unit)
    575         if _module_class is None:
    576             _module_class = TopLevelTracedModule
--> 577         return _module_class(mod, _compilation_unit=_compilation_unit)
    578 
    579 

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in __init__(self, orig, id_set, _compilation_unit)
   1075                 continue
   1076             tmp_module._modules[name] = make_module(
-> 1077                 submodule, TracedModule, _compilation_unit=None
   1078             )
   1079 

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in make_module(mod, _module_class, _compilation_unit)
    575         if _module_class is None:
    576             _module_class = TopLevelTracedModule
--> 577         return _module_class(mod, _compilation_unit=_compilation_unit)
    578 
    579 

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in __init__(self, orig, id_set, _compilation_unit)
   1079 
   1080         script_module = torch.jit._recursive.create_script_module(
-> 1081             tmp_module, lambda module: (), share_types=False, is_tracing=True
   1082         )
   1083 

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    453     assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
    454     check_module_initialized(nn_module)
--> 455     concrete_type = get_module_concrete_type(nn_module, share_types)
    456     if not is_tracing:
    457         AttributeTypeIsSupportedChecker().check(nn_module)

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_recursive.py in get_module_concrete_type(nn_module, share_types)
    408         # Get a concrete type directly, without trying to re-use an existing JIT
    409         # type from the type store.
--> 410         concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
    411         concrete_type_builder.set_poisoned()
    412         concrete_type = concrete_type_builder.build()

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_recursive.py in infer_concrete_type_builder(nn_module, share_types)
    220     # Constants annotated via `Final[T]` rather than being added to `__constants__`
    221     for name, ann in class_annotations.items():
--> 222         if torch._jit_internal.is_final(ann):
    223             constants_set.add(name)
    224 

~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/_jit_internal.py in is_final(ann)
    941 
    942 def is_final(ann) -> bool:
--> 943     return ann.__module__ in {'typing', 'typing_extensions'} and \
    944         (getattr(ann, '__origin__', None) is Final or isinstance(ann, type(Final)))
    945 

AttributeError: 'NoneType' object has no attribute '__module__'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant