Support ONNX export of selected filterbank modules #10
Thanks for the issue.
Yes, not much luck.
OK.
Me neither. I was just super excited to get a filterbank frontend working with onnxruntime for a desktop app demo. But for now we'd better watch PyTorch upstream and re-check when the error messages get more precise. I would suggest leaving this open as a reminder.
I recently had this problem too and simply switched to a SciPy-based STFT, i.e., not having the filterbank in PyTorch.
Then the problem arises when you want to have the iSTFT in the network, for time-domain losses, right?
Maybe it's actually not that complicated to fix this.
@faroit have you tried with a simple nn.Conv1d?
Yes, but if you put the loss outside the model, there's no problem. I don't need the loss in the ONNX export.
How did you export the SciPy STFT to ONNX?
I didn't; SciPy simply has to be installed by whoever uses the ONNX export.
@jonashaag okay, that wasn't the use case I had in mind. Having the full end-to-end model in ONNX gives you the flexibility to perform audio processing, e.g., from Node.js/Electron, without having to reimplement the pre/post-processing pipeline in JS. Python would not be an option in that case.
Yes, that works. There are also other STFT variants that can be exported. I guess it's a trivial thing, but we may not be able to track it down without much effort until the error tracing improves.
nnAudio's implementation seems to be ONNX-exportable. We might want to check the differences... KinWaiCheuk/nnAudio#23 (comment)
I used two nn.Conv1d layers to write my STFT class in nnAudio. One Conv1d is for the real part (cosine kernels), the other is for the imaginary part (sine kernels). May I know how you implement the STFT in asteroid-filterbanks?
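A minimal NumPy sketch of the two-Conv1d idea described above (not nnAudio's or asteroid-filterbanks' actual code): each one-sided frequency bin gets a windowed cosine kernel for the real part and a windowed, negated sine kernel for the imaginary part, and both kernel banks are slid over the signal with stride `hop`. The function name `conv_stft`, the Hann window and the parameter values are assumptions for illustration. In PyTorch, the two kernel banks would be the weights (shape `(n_bins, 1, n_fft)`) of two `nn.Conv1d` layers with `kernel_size=n_fft` and `stride=hop`.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv_stft(x: np.ndarray, n_fft: int, hop: int) -> tuple[np.ndarray, np.ndarray]:
    """STFT as two strided correlations: a bank of cosine kernels gives the
    real part, a bank of (negated) sine kernels gives the imaginary part."""
    window = np.hanning(n_fft)  # assumption: Hann analysis window
    n = np.arange(n_fft)
    k = np.arange(n_fft // 2 + 1)[:, None]  # one kernel per one-sided frequency bin
    cos_kernels = window * np.cos(2 * np.pi * k * n / n_fft)   # (n_bins, n_fft)
    sin_kernels = -window * np.sin(2 * np.pi * k * n / n_fft)  # minus sign from e^{-j...}

    # Frames taken every `hop` samples: exactly the positions a Conv1d with
    # kernel_size=n_fft and stride=hop visits (Conv1d is a cross-correlation).
    frames = sliding_window_view(x, n_fft)[::hop]  # (n_frames, n_fft)
    real = frames @ cos_kernels.T  # (n_frames, n_bins)
    imag = frames @ sin_kernels.T  # (n_frames, n_bins)
    return real, imag


rng = np.random.default_rng(0)
signal = rng.standard_normal(1024)
real, imag = conv_stft(signal, n_fft=256, hop=64)
```

The result matches `np.fft.rfft` of the windowed frames, which is why a plain strided convolution (a standard, exportable operation) is enough to compute an STFT.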
Using the functional API.
I can convert the free, param and STFT filterbanks with PyTorch nightly, but not the torch_stft one yet.
I checked all the hooks; the only one that doesn't pass is … Now that this is much narrower, Fabian, would you like to have a look?
@mpariente next issue: for the decoder, torch.fold is not supported... (asteroid-filterbanks/asteroid_filterbanks/torch_stft_fb.py, lines 172 to 176 in 3510292)
What's a good replacement?
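For a 1D signal, the fold step in an ISTFT decoder performs overlap-add. One possible replacement (an assumption, not what this thread ended up doing) is a transposed convolution with a fixed identity kernel, since ONNX has a standard ConvTranspose operator. The NumPy sketch below, with hypothetical helper names `overlap_add` and `conv_transpose1d`, only checks that the two computations agree; in PyTorch the corresponding call would be `torch.nn.functional.conv_transpose1d(frames, torch.eye(n_fft).unsqueeze(1), stride=hop)` for `frames` of shape `(batch, n_fft, n_frames)`. Window-sum normalization is left out.

```python
import numpy as np


def overlap_add(frames: np.ndarray, hop: int) -> np.ndarray:
    """Reference overlap-add (what 1D fold computes) for frames of shape (n_fft, n_frames)."""
    n_fft, n_frames = frames.shape
    out = np.zeros((n_frames - 1) * hop + n_fft)
    for t in range(n_frames):
        out[t * hop : t * hop + n_fft] += frames[:, t]
    return out


def conv_transpose1d(inp: np.ndarray, weight: np.ndarray, stride: int) -> np.ndarray:
    """Naive transposed 1D convolution (single batch item, no padding).
    inp: (c_in, l_in); weight: (c_in, c_out, k), i.e. PyTorch's weight layout."""
    c_in, l_in = inp.shape
    _, c_out, k = weight.shape
    out = np.zeros((c_out, (l_in - 1) * stride + k))
    for t in range(l_in):
        # Each input position scatters a weighted copy of the kernel into the output.
        out[:, t * stride : t * stride + k] += np.einsum("i,iok->ok", inp[:, t], weight)
    return out


rng = np.random.default_rng(0)
n_fft, n_frames, hop = 8, 5, 2
frames = rng.standard_normal((n_fft, n_frames))
eye_kernel = np.eye(n_fft)[:, None, :]  # (n_fft, 1, n_fft): channel i writes to tap i
y = conv_transpose1d(frames, eye_kernel, stride=hop)
```

With the identity kernel, input channel `i` at frame `t` lands at output sample `t * hop + i`, which is exactly overlap-add.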
Are you sure about that? Replacing it will be very cumbersome.
Seems so: pytorch/pytorch#41423
Did you try with the functional API? It's probably the same, but worth a try.
(In reply to Fabian-Robert Stöter, Fri 12 Mar 2021 at 15:08.)
@mpariente @jonashaag it works with torch 12 and opset > 11 now! Should I still add some tests?
Cool, thanks Fabian! If you could, that'd be great!
Feel free to use this repo to export your custom STFT or ISTFT process to ONNX format easily with
One of the benefits of 1D-conv-based filterbanks is that they can be exported for deployment more easily.
Testing TorchSTFTFB reveals that ONNX export doesn't currently work, and it's not clear where the error stems from. Example of the encoder's traced module, exported with ONNX:
results in: