Support ONNX export of selected filterbank modules #10

Open
faroit opened this issue Jan 22, 2021 · 25 comments
@faroit
Contributor

faroit commented Jan 22, 2021

One of the benefits of 1D-conv-based filterbanks is that they can be exported more easily for deployment.

Testing TorchSTFTFB reveals that ONNX export doesn't currently work, and it's not clear where the error stems from, due to this.

Example of a traced encoder module exported with ONNX:

    import torch.onnx
    from asteroid_filterbanks.enc_dec import Encoder
    from asteroid_filterbanks import torch_stft_fb

    nb_samples = 1
    nb_channels = 2
    nb_timesteps = 11111

    example = torch.rand((nb_samples, nb_channels, nb_timesteps))

    fb = torch_stft_fb.TorchSTFTFB(n_filters=512, kernel_size=512)
    enc = Encoder(fb)
    torch_out = enc(example)
    # Export the model
    torch.onnx.export(
        enc,
        example,
        "umx.onnx",
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        verbose=True
    )

results in

Traceback (most recent call last):
  File "onnx.py", line 28, in <module>
    verbose=False
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 230, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 91, in export
    use_external_data_format=use_external_data_format)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 639, in _export
    dynamic_axes=dynamic_axes)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 421, in _model_to_graph
    dynamic_axes=dynamic_axes, input_names=input_names)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 203, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 263, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 968, in _run_symbolic_function
    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 263, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 979, in _run_symbolic_function
    operator_export_type)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 888, in _find_symbolic_in_registry
    return sym_registry.get_registered_op(op_name, domain, opset_version)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/symbolic_registry.py", line 111, in get_registered_op
    raise RuntimeError(msg)
RuntimeError: Exporting the operator prim_Uninitialized to ONNX opset version 10 is not supported. Please open a bug to request ONNX export support for the missing operator.
@mpariente
Contributor

Thanks for the issue.
Did you try the other filterbanks?

@faroit
Contributor Author

faroit commented Jan 24, 2021

Yes, not much luck:

  • free, stft, param_sinc suffer from this: maybe avoiding x, y, y = data.shape could help.
  • analytic_free: neither rfft, fft nor ifft are available in ONNX

@mpariente
Contributor

Ok..
Not sure we have the bandwidth to sort this out without outside help, if you'd like to work on this, please do!

@faroit
Contributor Author

faroit commented Jan 24, 2021

Not sure we have the bandwidth to sort this out without outside help, if you'd like to work on this, please do!

Me neither. I was just super excited to get a filterbank frontend working with onnxruntime for a desktop app demo. But for now we'd better watch PyTorch upstream and re-check when the error messages get more precise. I'd suggest leaving this open as a reminder.

@jonashaag
Contributor

I recently had this problem too and simply switched to a SciPy-based STFT, i.e. not having the filterbank in PyTorch.

@mpariente
Contributor

I recently had this problem too and simply switched to a SciPy-based STFT, i.e. not having the filterbank in PyTorch.

Then the problem arises when you want to have the iSTFT in the network, for time-domain losses, right?

@mpariente
Contributor

Maybe it's actually not that complicated to fix this.
We could make a new class, as simple as possible, that just makes the conv or transposed conv, with fixed filters. And we'd make an export method from Encoder and Decoder.

@mpariente
Contributor

@faroit have you tried with a simple nn.Conv1d?

@jonashaag
Contributor

Then the problem arises when you want to have the iSTFT in the network, for time-domain losses, right?

Yes, but if you put the loss outside the model, there's no problem. I don't need the loss in the ONNX export.
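A minimal sketch of that split (hypothetical mask model, not asteroid code): the exported module only maps spectrograms to spectrograms, while the time-domain loss lives in the training loop and never has to be ONNX-exportable:

```python
import torch

# Hypothetical spectrogram-to-spectrogram model: only export-friendly ops inside.
class MaskModel(torch.nn.Module):
    def __init__(self, n_bins: int = 257):
        super().__init__()
        self.net = torch.nn.Linear(n_bins, n_bins)

    def forward(self, mag: torch.Tensor) -> torch.Tensor:
        # (batch, frames, bins) -> masked magnitudes
        return torch.sigmoid(self.net(mag)) * mag

model = MaskModel()
mag = torch.rand(2, 100, 257)     # magnitudes, e.g. from a SciPy STFT outside the model
target = torch.rand(2, 100, 257)

est = model(mag)
# The loss (and any iSTFT feeding a time-domain loss) is computed here in the
# training loop, outside the module that gets traced for ONNX export.
loss = torch.nn.functional.l1_loss(est, target)
```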

@faroit
Contributor Author

faroit commented Jan 24, 2021

I recently had this problem too and simply switched to a SciPy-based STFT, i.e. not having the filterbank in PyTorch.

How did you export the SciPy STFT to ONNX?

@jonashaag
Contributor

I didn't; SciPy needs to be installed on the side of whoever uses the ONNX export.

@faroit
Contributor Author

faroit commented Jan 24, 2021

I didn't; SciPy needs to be installed on the side of whoever uses the ONNX export.

@jonashaag okay, that wasn't the use case I had in mind. Having the full end-to-end model in ONNX gives you the flexibility to perform audio processing e.g. from node.js/electron without having to reimplement the pre/post-processing pipeline in JS. Python would not be an option in that case.

@faroit
Contributor Author

faroit commented Jan 24, 2021

@faroit have you tried with a simple nn.Conv1d?

Yes, that works. There are also other STFT variants that can be exported. I guess it's a trivial thing, but we may not be able to track this down without much effort until the error tracing improves.

@faroit
Contributor Author

faroit commented Jan 28, 2021

nnAudio's implementation seems to be ONNX-exportable. We might want to check the differences... KinWaiCheuk/nnAudio#23 (comment)

@KinWaiCheuk

nnAudio's implementation seems to be ONNX-exportable. We might want to check the differences... KinWaiCheuk/nnAudio#23 (comment)

I used two nn.Conv1d layers to write my STFT class in nnAudio: one Conv1d for the real part (cos kernels) and another Conv1d for the imaginary part (sin kernels). May I know how you implement the STFT in asteroid-filterbanks?
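That scheme can be sketched like this (a simplified rectangular-window version, not nnAudio's actual code). Since everything is a plain Conv1d, tracing and ONNX export have nothing exotic to handle:

```python
import math
import torch

class ConvSTFT(torch.nn.Module):
    """STFT as two Conv1d layers: cos kernels give the real part,
    sin kernels the imaginary part. Simplified sketch (rectangular
    window), not nnAudio's actual implementation."""

    def __init__(self, n_fft: int = 512, hop: int = 256):
        super().__init__()
        n_bins = n_fft // 2 + 1
        t = torch.arange(n_fft, dtype=torch.float32)
        f = torch.arange(n_bins, dtype=torch.float32)
        angles = 2 * math.pi * torch.outer(f, t) / n_fft          # (bins, n_fft)
        self.real = torch.nn.Conv1d(1, n_bins, n_fft, stride=hop, bias=False)
        self.imag = torch.nn.Conv1d(1, n_bins, n_fft, stride=hop, bias=False)
        # DFT: X[f] = sum_t x[t] * (cos - i*sin), hence the minus sign on sin
        self.real.weight.data.copy_(torch.cos(angles).unsqueeze(1))
        self.imag.weight.data.copy_(-torch.sin(angles).unsqueeze(1))

    def forward(self, x: torch.Tensor):
        # x: (batch, 1, samples) -> two (batch, bins, frames) tensors
        return self.real(x), self.imag(x)
```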

@mpariente
Contributor

Using the functional API.
Have a look at the Encoder.forward.

@mpariente
Contributor

I can convert the free, param and STFT filterbanks with PyTorch nightly, but not the torch_stft version yet.
Regarding the analytic filterbank, there is some hope here

@mpariente
Contributor

I checked all the hooks; the only one that doesn't pass is pre_analysis, which does the padding.
This is the function.

Now that this is much narrower, Fabian, would you like to have a look?

@faroit
Contributor Author

faroit commented Mar 12, 2021

@mpariente next issue: to address the decoder, torch.fold is not supported....

    def square_ola(window: torch.Tensor, kernel_size: int, stride: int, n_frame: int) -> torch.Tensor:
        window_sq = window.pow(2).view(1, -1, 1).repeat(1, 1, n_frame)
        return torch.nn.functional.fold(
            window_sq, (1, (n_frame - 1) * stride + kernel_size), (1, kernel_size), stride=(1, stride)
        ).squeeze(2)

What's a good replacement?
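For reference, one fold-free way to get the same squared-window overlap-add (a sketch, not the fix that ended up in asteroid) is to drive a transposed convolution with ones, using the squared window as the kernel:

```python
import torch

def square_ola_no_fold(window: torch.Tensor, kernel_size: int,
                       stride: int, n_frame: int) -> torch.Tensor:
    # conv_transpose1d driven by ones drops one copy of the squared window
    # every `stride` samples and sums the overlaps -- the same result as
    # the fold() version, built from a longer-supported ONNX op.
    window_sq = window.pow(2).view(1, 1, -1)   # (in_ch=1, out_ch=1, kernel_size)
    frames = torch.ones(1, 1, n_frame)         # one unit impulse per frame
    return torch.nn.functional.conv_transpose1d(frames, window_sq, stride=stride)
```

The output has shape (1, 1, (n_frame - 1) * stride + kernel_size), matching the fold-based version.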

@mpariente
Contributor

Are you sure about that?
Replacing it will be very cumbersome

@faroit
Contributor Author

faroit commented Mar 12, 2021

Are you sure about that?
Replacing it will be very cumbersome

seems so: pytorch/pytorch#41423

@mpariente
Contributor

mpariente commented Mar 12, 2021 via email

@faroit
Contributor Author

faroit commented Jul 22, 2022

@mpariente @jonashaag works with torch 1.12 and opset > 11 now!

Should I still add some tests?

@mpariente
Contributor

Cool, thanks Fabian!

If you could, that'd be great!

@DakeQQ

DakeQQ commented Dec 20, 2024

Feel free to use this repo to export your custom STFT or ISTFT process to ONNX format easily with torch.onnx.export().
