Skip to content

Commit

Permalink
Merge pull request #170 from google-ai-edge/jingjinnodedata2
Browse files Browse the repository at this point in the history
Update API to make it more user-friendly by allowing users to pass a single node data item instead of a list
  • Loading branch information
jinjingforever authored Sep 12, 2024
2 parents 376d443 + 9226e39 commit ae97c48
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions src/server/package/src/model_explorer/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.
# ==============================================================================

from typing import Union, TypedDict
from typing_extensions import NotRequired
from typing import TypedDict, Union

import torch
from typing_extensions import NotRequired

from . import server
from .config import ModelExplorerConfig, NodeData
Expand Down Expand Up @@ -58,7 +58,7 @@ def visualize(
host=DEFAULT_HOST,
port=DEFAULT_PORT,
extensions: list[str] = [],
node_data_list: list[NodeDataInfo] = [],
node_data: Union[NodeDataInfo, list[NodeDataInfo]] = [],
colab_height=DEFAULT_COLAB_HEIGHT,
reuse_server: bool = False,
reuse_server_host: str = DEFAULT_HOST,
Expand All @@ -71,7 +71,7 @@ def visualize(
host: The host of the server. Default to localhost.
port: The port of the server. Default to 8080.
extensions: List of extension names to be run with model explorer.
node_data_list: The list of node data to display.
node_data: The node data or a list of node data to display.
colab_height: The height of the embedded iFrame when running in colab.
reuse_server: Whether to reuse the current server/browser tab(s) to
visualize.
Expand All @@ -88,9 +88,7 @@ def visualize(
for model_path in model_paths_list:
cur_config.add_model_from_path(path=model_path)

_add_node_data_list_to_config(
node_data_list=node_data_list, config=cur_config
)
_add_node_data_to_config(node_data=node_data, config=cur_config)

if reuse_server:
cur_config.set_reuse_server(
Expand All @@ -113,7 +111,7 @@ def visualize_pytorch(
host=DEFAULT_HOST,
port=DEFAULT_PORT,
extensions: list[str] = [],
node_data_list: list[NodeDataInfo] = [],
node_data: Union[NodeDataInfo, list[NodeDataInfo]] = [],
colab_height=DEFAULT_COLAB_HEIGHT,
settings=DEFAULT_SETTINGS,
) -> None:
Expand All @@ -125,7 +123,7 @@ def visualize_pytorch(
host: The host of the server. Default to localhost.
port: The port of the server. Default to 8080.
extensions: List of extension names to be run with model explorer.
node_data_list: The list of node data to display.
node_data: The node data or a list of node data to display.
colab_height: The height of the embedded iFrame when running in colab.
settings: The settings that config the visualization.
"""
Expand All @@ -135,9 +133,7 @@ def visualize_pytorch(
name, exported_program=exported_program, settings=settings
)

_add_node_data_list_to_config(
node_data_list=node_data_list, config=cur_config
)
_add_node_data_to_config(node_data=node_data, config=cur_config)

# Start server.
server.start(
Expand Down Expand Up @@ -182,17 +178,25 @@ def visualize_from_config(
)


def _add_node_data_list_to_config(
node_data_list: list[NodeDataInfo], config: ModelExplorerConfig
def _add_node_data_to_config(
node_data: Union[NodeDataInfo, list[NodeDataInfo]],
config: ModelExplorerConfig,
):
# Convert NodeDataInfo to [NodeDataInfo] if necessary.
node_data_list: list[NodeDataInfo] = []
if isinstance(node_data, list):
node_data_list = node_data
else:
node_data_list = [node_data]

for node_data_info in node_data_list:
name = node_data_info.get('name', 'node data')
node_data_path = node_data_info.get('node_data_path')
node_data = node_data_info.get('node_data')
node_data_obj = node_data_info.get('node_data')
model_name = node_data_info.get('model_name')
if node_data:
if node_data_obj:
config.add_node_data(
name=name, node_data=node_data, model_name=model_name
name=name, node_data=node_data_obj, model_name=model_name
)
elif node_data_path:
config.add_node_data_from_path(path=node_data_path, model_name=model_name)

0 comments on commit ae97c48

Please sign in to comment.