Skip to content

Commit

Permalink
✨ feat(tools): 修改了onnxsim 函数
Browse files Browse the repository at this point in the history
  • Loading branch information
bruce1408 committed Mar 4, 2024
1 parent 54af254 commit 7c16dfe
Showing 1 changed file with 47 additions and 4 deletions.
51 changes: 47 additions & 4 deletions Tools/onnx_model_convert_onnxSimplify.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# import torch
# # import os
import os
import onnx
import onnxruntime as ort
from onnxsim import simplify
from printk import print_colored_box, print_colored_box_line


def onnx_simplify(path):
# simplify the onnx model & load onnx model
# simplify the onnx model & load onnx model and check if it is valid
onnx.load(path)
onnx_model = onnx.load(path)
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
Expand All @@ -14,8 +17,48 @@ def onnx_simplify(path):

onnx.save(model_simp, output_path)

print("The simplified ONNX model is saved to {}".format(output_path))
print_colored_box(f"The simplified ONNX model is saved in {output_path}")


def onnxruntime_inference(onnx_model_path):
# 加载ONNX模型
ort_session = ort.InferenceSession(onnx_model_path)

for input_tensor in ort_session.get_inputs():
# 获取输入的名称、形状和数据类型
input_name = input_tensor.name
input_shape = input_tensor.shape
print_colored_box(f"input name is {input_name}, Shape: {input_shape}, type is {input_tensor.type}", 50)


def print_onnx_input_output(model_path):
# 加载ONNX模型
model = onnx.load(model_path)

# 打印模型的输入信息
print("Model Inputs:")
for input in model.graph.input:
# 打印输入的形状和类型
shape = [dim.dim_value for dim in input.type.tensor_type.shape.dim]
data_type = input.type.tensor_type.elem_type

# print_colored_box(f"input name is {input.name}, Shape: {shape}, type is {onnx.TensorProto.DataType.Name(data_type)}", 50)
print(input.name, end=': ')

# 打印模型的输出信息
print("\nModel Outputs:")
for output in model.graph.output:
print(output.name, end=': ')
# 打印输出的形状和类型
shape = [dim.dim_value for dim in output.type.tensor_type.shape.dim]
data_type = output.type.tensor_type.elem_type
print("Shape:", shape, "Type:", onnx.TensorProto.DataType.Name(data_type))

if __name__=="__main__":
onnx_simplify("/mnt/share_disk/cdd/layernorm_with_official_opset_17.onnx")
model_path = "/Users/bruce/Downloads/8620_deploy/swin_tiny_patch4_window7_224_224_elementwise_affine.onnx"
# model_path = "/Users/bruce/Downloads/8620_deploy/Laneline/models/epoch_latest_0302.onnx"

onnx_simplify(model_path)
# print_onnx_input_output(model_path)
# onnxruntime_inference(model_path)

0 comments on commit 7c16dfe

Please sign in to comment.