can't explore UNet2DConditionModel from diffusers lib #136

Closed
hanruijiang opened this issue Aug 27, 2024 · 3 comments
Labels: PyTorch adapter, stat:awaiting model-navigator, type:bug

Comments

@hanruijiang

Hi,

I can't visualize UNet2DConditionModel from the diffusers library with model_explorer.

Could you please help me solve this problem?

Environment

Google Colab

torch 2.4.0+cu121
torchvision 0.19.0+cu121
diffusers 0.30.1
ai-edge-model-explorer 0.1.10
ai-edge-model-explorer-adapter 0.1.5

Steps to Reproduce:

```python
!pip install ai-edge-model-explorer
!pip install diffusers

import os
import torch
import torchvision
from diffusers import UNet2DConditionModel
import model_explorer

import requests

# Build the Kandinsky 2.2 decoder UNet from its published config (randomly initialized weights).
url = 'https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/resolve/main/unet/config.json'

response = requests.get(url)
config = response.json()

unet = UNet2DConditionModel(**config)

image_embeds = torch.randn((1, 1280))
latent_model_input = torch.randn((1, 4, 64, 64))
t = torch.tensor(1.).long()
added_cond_kwargs = {"image_embeds": image_embeds}

# encoder_hidden_states (third positional argument) is None;
# the image embeddings are passed through added_cond_kwargs instead.
ep = torch.export.export(
    unet,
    args=(latent_model_input, t, None),
    kwargs={
        'added_cond_kwargs': added_cond_kwargs,
        'return_dict': False
    },
)

model_explorer.visualize_pytorch('unet', exported_program=ep)
```

Error log

```
Converting pytorch model to model explorer graphs...
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-7-8163fd1130fd> in <cell line: 3>()
      1 import model_explorer
      2 # Visualize
----> 3 model_explorer.visualize_pytorch('unet', exported_program=ep)

/usr/local/lib/python3.10/dist-packages/model_explorer/apis.py in visualize_pytorch(name, exported_program, host, port, extensions, colab_height, settings)
     80   # Construct config.
     81   cur_config = config()
---> 82   cur_config.add_model_from_pytorch(
     83       name, exported_program=exported_program, settings=settings
     84   )

/usr/local/lib/python3.10/dist-packages/model_explorer/config.py in add_model_from_pytorch(self, name, exported_program, settings)
     88     print('Converting pytorch model to model explorer graphs...')
     89     adapter = PytorchExportedProgramAdapterImpl(exported_program, settings)
---> 90     graphs = adapter.convert()
     91     graphs_index = len(self.graphs_list)
     92     self.graphs_list.append(graphs)

/usr/local/lib/python3.10/dist-packages/model_explorer/pytorch_exported_program_adater_impl.py in convert(self)
    236 
    237   def convert(self) -> ModelExplorerGraphs:
--> 238     return {'graphs': [self.create_graph()]}

/usr/local/lib/python3.10/dist-packages/model_explorer/pytorch_exported_program_adater_impl.py in create_graph(self)
    232     graph = Graph(id='graph', nodes=[])
    233     for node in self.gm.graph.nodes:
--> 234       graph.nodes.append(self.create_node(node))
    235     return graph
    236 

/usr/local/lib/python3.10/dist-packages/model_explorer/pytorch_exported_program_adater_impl.py in create_node(self, fx_node)
    225     )
    226     self.add_incoming_edges(fx_node, node)
--> 227     self.add_node_attrs(fx_node, node)
    228     self.add_outputs_metadata(fx_node, node)
    229     return node

/usr/local/lib/python3.10/dist-packages/model_explorer/pytorch_exported_program_adater_impl.py in add_node_attrs(self, fx_node, node)
    183               KeyValue(
    184                   key='__value',
--> 185                   value=self.print_tensor(
    186                       tensor, self.settings['const_element_count_limit']
    187                   ),

/usr/local/lib/python3.10/dist-packages/model_explorer/pytorch_exported_program_adater_impl.py in print_tensor(self, tensor, size_limit)
    152 
    153   def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
--> 154     shape = tensor.shape
    155     total_size = 1
    156     for dim in shape:

AttributeError: 'bool' object has no attribute 'shape'
```
@pkgoogle
Contributor

Hi @hanruijiang, I was able to reproduce this exactly as you described. It seems that Model Explorer's PyTorch adapter expects all arguments to be tensors: https://github.com/google-ai-edge/model-explorer/blob/main/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py#L177

Hi @yijie-yang, can you please take a look?
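For illustration only, here is a minimal sketch of how the failing formatter could tolerate non-tensor constants. `format_constant` is a hypothetical name and is not part of model_explorer. It assumes only that tensor constants expose a `.shape` (as `torch.Tensor` does) and that Python scalars such as `bool` do not, which matches the `AttributeError: 'bool' object has no attribute 'shape'` in the traceback above.

```python
from typing import Any


def format_constant(value: Any, size_limit: int = 16) -> str:
    """Render a graph constant as text without assuming it is a tensor.

    Hypothetical helper (not a model_explorer API): values exposing `.shape`
    (e.g. torch.Tensor) are summarized by shape when large, while plain Python
    constants such as the `False` from `return_dict=False` fall back to repr().
    """
    shape = getattr(value, "shape", None)
    if shape is None:
        # bool / int / float / None constants have no .shape; the reported
        # traceback fails on `tensor.shape` for exactly this kind of value.
        return repr(value)

    # Count elements by multiplying the dimensions.
    total_size = 1
    for dim in shape:
        total_size *= int(dim)

    if total_size > size_limit:
        # Too large to print inline; report only the shape and element count.
        return f"<tensor shape={tuple(shape)} elements={total_size}>"
    return str(value)
```

With this helper, `format_constant(False)` returns `'False'` instead of raising, and a constant of shape (64, 64) would be summarized as `<tensor shape=(64, 64) elements=4096>` rather than printed in full.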

@pkgoogle added the type:bug, stat:awaiting model-navigator and PyTorch adapter labels on Aug 27, 2024
@hanruijiang
Author

Hi @pkgoogle ,

You're right. I worked around this bug by removing 'return_dict': False from the kwargs.

Thank you for your help.

@pkgoogle
Contributor

@hanruijiang, thanks, that worked for me as well. If you have no more open items, feel free to close this issue.

3 participants