Torch unsqueeze / reshape dimensions not preserved upon TFLite conversion #288

Open
gudgud96 opened this issue Oct 8, 2024 · 1 comment

gudgud96 commented Oct 8, 2024

Description of the bug:

I am working on making my model GPU-compatible with TFLite, and one of the issues to solve is working around operators that do not support broadcasting on the GPU. One solution is to keep the tensors at the same rank using reshape or unsqueeze, but I found that the expanded dimensions are not preserved after TFLite conversion.
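
The rank-matching idea can be illustrated at the shape level. The sketch below is stdlib-only and assumes the standard NumPy/PyTorch broadcasting rules; `broadcast_shape`, `needs_broadcast` and `unsqueeze_to_rank` are hypothetical helpers, not part of PyTorch, TFLite or ai_edge_torch. It also assumes, as the GPU compatibility check appears to, that any difference between the two ADD input shapes counts as broadcasting.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Shape:
    """Result shape of an elementwise op under NumPy/PyTorch broadcasting rules."""
    rank = max(len(a), len(b))
    # Right-align the shapes by left-padding the shorter one with 1s.
    pa = (1,) * (rank - len(a)) + tuple(a)
    pb = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for da, db in zip(pa, pb):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
    return tuple(out)


def needs_broadcast(a: Sequence[int], b: Sequence[int]) -> bool:
    """True if an elementwise op on a and b relies on broadcasting (shapes differ)."""
    return tuple(a) != tuple(b)


def unsqueeze_to_rank(shape: Sequence[int], rank: int) -> Shape:
    """Prepend size-1 dims (like repeated unsqueeze(0)) until shape has the given rank."""
    if len(shape) > rank:
        raise ValueError(f"shape {tuple(shape)} already has rank > {rank}")
    return (1,) * (rank - len(shape)) + tuple(shape)


if __name__ == "__main__":
    residual, branch = (1, 20, 32), (20, 32)
    print(broadcast_shape(residual, branch))                       # (1, 20, 32)
    print(needs_broadcast(residual, branch))                       # True
    print(needs_broadcast(residual, unsqueeze_to_rank(branch, 3)))  # False
```

Both operands broadcast to the same result, but only the rank-aligned pair avoids broadcasting altogether, which is what the GPU path appears to require.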

Here is a minimal reproduction:

import torch
import torch.nn as nn
import ai_edge_torch

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.linear2 = nn.Linear(32, 32)

    def forward(self, x):
        """
        x - (1, 20, 32)
        """
        B, T, D = x.shape
        tmp = x                                                 # residual branch, (1, 20, 32)
        x = x.reshape(4, B * T, D // 4)                         # (4, 20, 8)
        x = self.linear(x)                                      # (4, 20, 8)
        x = x.transpose(0, 1).contiguous().reshape(1, T, D)     # (1, 20, 32)
        x = self.linear2(x)                                     # (1, 20, 32)

        # residual add: both operands are expected to be (1, 20, 32)
        x = tmp + x
        x = x.squeeze(0)                                        # (20, 32)
        return x

test_model = TestModel()
x = torch.rand(1, 20, 32)
with torch.no_grad():
    y = test_model(x)

edge_model = ai_edge_torch.convert(test_model, (x,))
edge_model.export("test_model.tflite")

At the residual add line, tmp and x are expected to have the same shape, (1, 20, 32).
However, after converting to TFLite, running

import tensorflow as tf
tf.lite.experimental.Analyzer.analyze(model_path="test_model.tflite", gpu_compatibility=True)

gives me incompatibility warnings, which prevent the model from fully utilizing the GPU:

Subgraph#0 main(T#0) -> [T#14]
  Op#0 RESHAPE(T#0, T#7[4, 20, 8]) -> [T#8]
  Op#1 FULLY_CONNECTED(T#8, T#2, T#5) -> [T#9]
  Op#2 TRANSPOSE(T#9, T#4[1, 0, 2]) -> [T#10]
  Op#3 RESHAPE(T#10, T#3[20, 32]) -> [T#11]
  Op#4 FULLY_CONNECTED(T#11, T#1, T#6) -> [T#12]
  Op#5 ADD(T#0, T#12) -> [T#13]
GPU COMPATIBILITY WARNING: Doesn't support broadcasting - input0: [1,20,32], input1: [20,32]
  Op#6 RESHAPE(T#13, T#3[20, 32]) -> [T#14]

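It seems the conversion "squeezes" the output of the second Reshape op (Op#3) to a 2-D tensor [20, 32] instead of the 3-D tensor [1, 20, 32] produced in PyTorch, so the ADD receives inputs of different ranks.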
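
To make the mismatch concrete, the stdlib-only sketch below replays the shapes of the ops listed by the analyzer. It models shapes only, not values; the helper names are hypothetical, and `fully_connected` assumes FULLY_CONNECTED keeps the leading dimensions and replaces only the last one (the 3-element permutation of the following TRANSPOSE suggests the tensor stays rank 3 at that point).

```python
from math import prod
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def reshape(shape: Shape, new_shape: Sequence[int]) -> Shape:
    # A reshape must preserve the total number of elements.
    if prod(shape) != prod(new_shape):
        raise ValueError(f"cannot reshape {shape} to {tuple(new_shape)}")
    return tuple(new_shape)


def transpose(shape: Shape, perm: Sequence[int]) -> Shape:
    # Output dim i takes the size of input dim perm[i].
    return tuple(shape[p] for p in perm)


def fully_connected(shape: Shape, out_features: int) -> Shape:
    # Assumption: leading dims are kept, only the feature (last) dim changes.
    return shape[:-1] + (out_features,)


def trace(op3_shape: Sequence[int]) -> Tuple[Shape, Shape]:
    """Return the two ADD (Op#5) input shapes for a given Op#3 reshape target."""
    x = (1, 20, 32)                     # T#0, model input
    h = reshape(x, (4, 20, 8))          # Op#0
    h = fully_connected(h, 8)           # Op#1 -> (4, 20, 8)
    h = transpose(h, (1, 0, 2))         # Op#2 -> (20, 4, 8)
    h = reshape(h, op3_shape)           # Op#3
    h = fully_connected(h, 32)          # Op#4
    return x, h


if __name__ == "__main__":
    print(trace((20, 32)))     # converted graph: ((1, 20, 32), (20, 32)) -> broadcast
    print(trace((1, 20, 32)))  # PyTorch shapes:  ((1, 20, 32), (1, 20, 32))
```

With the [20, 32] target that appears in the converted graph, the ADD inputs differ in rank; with the [1, 20, 32] target from the PyTorch code, they match.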

I would like to understand how the Reshape output shape is optimized or changed during conversion. That would help me see how to preserve the expanded dimensions and keep both ADD inputs the same shape, so that the "Doesn't support broadcasting" warning goes away.

Actual vs expected behavior:

Expected: the reshape dimensions in the converted TFLite model match the PyTorch tensor shapes.
Actual: the converted TFLite tensors appear to be squeezed by some unknown optimization.

Any other information you'd like to share?

No response

gudgud96 added the type:bug label Oct 8, 2024
pkgoogle self-assigned this Oct 8, 2024

pkgoogle (Contributor) commented Oct 8, 2024

I was able to reproduce this exactly as described above.
