Skip to content

Commit

Permalink
add resnet export of onnx (#341)
Browse files Browse the repository at this point in the history
* add checkpoint_sync_export for resnet config
  • Loading branch information
liaogulou authored Jul 2, 2024
1 parent 1d1ac8a commit 8c90cea
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@
depth=50,
out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='BN')))

checkpoint_sync_export = True
export = dict(export_type='raw', export_neck=True)
43 changes: 40 additions & 3 deletions easycv/apis/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,37 @@ def _get_blade_model():
torch.jit.save(blade_model, ofile)


def _export_onnx_cls(model, model_config, cfg, filename, meta):

if model_config['backbone'].get(
'type', None) == 'ResNet' and model_config['backbone'].get(
'depth', None) == 50:
# save json config for test_pipline and class
with io.open(
filename +
'.config.json' if filename.endswith('onnx') else filename +
'.onnx.config.json', 'w') as ofile:
json.dump(meta, ofile)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)
img_size = int(cfg.image_size2)
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
torch.onnx.export(
model,
(x_input, 'onnx'),
filename if filename.endswith('onnx') else filename + '.onnx',
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
)
else:
raise ValueError('Only support export onnx model for ResNet now!')


def _export_cls(model, cfg, filename):
""" export cls (cls & metric learning)model and preprocess config
Expand All @@ -170,6 +201,7 @@ def _export_cls(model, cfg, filename):
else:
export_cfg = dict(export_neck=False)

export_type = export_cfg.get('export_type', 'raw')
export_neck = export_cfg.get('export_neck', True)
label_map_path = cfg.get('label_map_path', None)
class_list = None
Expand Down Expand Up @@ -232,9 +264,14 @@ def _export_cls(model, cfg, filename):
if export_neck and (k.startswith('neck') or k.startswith('head')):
state_dict[k] = v

checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
with io.open(filename, 'wb') as ofile:
torch.save(checkpoint, ofile)
if export_type == 'raw':
checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
with io.open(filename, 'wb') as ofile:
torch.save(checkpoint, ofile)
elif export_type == 'onnx':
_export_onnx_cls(model, model_config, cfg, filename, config)
else:
raise ValueError('Only support export onnx/raw model!')


def _export_yolox(model, cfg, filename):
Expand Down
17 changes: 17 additions & 0 deletions easycv/models/classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@ def forward_backbone(self, img: torch.Tensor) -> List[torch.Tensor]:
x = self.backbone(img)
return x

def forward_onnx(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
forward_onnx means generate prob from image only support one neck + one head
"""
x = self.forward_backbone(img) # tuple

# if self.neck_num > 0:
if hasattr(self, 'neck_0'):
x = self.neck_0([i for i in x])

out = self.head_0(x)[0].cpu()
out = self.activate_fn(out)
return out

@torch.jit.unused
def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
"""
Expand Down Expand Up @@ -290,6 +304,9 @@ def forward(
return self.forward_test_label(img, gt_labels)
else:
return self.forward_test(img)
elif mode == 'onnx':
return self.forward_onnx(img)

elif mode == 'extract':
rd = self.forward_feature(img)
rv = {}
Expand Down
58 changes: 58 additions & 0 deletions easycv/predictors/classifier.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import math
import os

import numpy as np
import torch
from PIL import Image

from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.misc import deprecated
from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
from .builder import PREDICTORS


# onnx specific
def onnx_to_numpy(tensor):
return tensor.detach().cpu().numpy(
) if tensor.requires_grad else tensor.cpu().numpy()


class ClsInputProcessor(InputProcessor):
"""Process inputs for classification models.
Expand Down Expand Up @@ -146,6 +155,20 @@ def __init__(self,
self.pil_input = pil_input
self.label_map_path = label_map_path

if model_path.endswith('onnx'):
self.model_type = 'onnx'
pwd_model = os.path.dirname(model_path)
raw_model = glob.glob(
os.path.join(pwd_model, '*.onnx.config.json'))
if len(raw_model) != 0:
config_file = raw_model[0]
else:
assert len(
raw_model
) == 0, 'Please have a file with the .onnx.config.json extension in your directory'
else:
self.model_type = 'raw'

if self.pil_input:
mode = 'RGB'
super(ClassificationPredictor, self).__init__(
Expand Down Expand Up @@ -186,6 +209,41 @@ def get_output_processor(self):

return ClsOutputProcessor(topk=self.topk, label_map=self.label_map)

def prepare_model(self):
"""Build model from config file by default.
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
if self.model_type == 'raw':
model = self._build_model()
model.to(self.device)
model.eval()
load_checkpoint(model, self.model_path, map_location='cpu')
return model
else:
import onnxruntime
if onnxruntime.get_device() == 'GPU':
onnx_model = onnxruntime.InferenceSession(
self.model_path, providers=['CUDAExecutionProvider'])
else:
onnx_model = onnxruntime.InferenceSession(self.model_path)

return onnx_model

def model_forward(self, inputs):
"""Model forward.
If you need refactor model forward, you need to reimplement it.
"""
with torch.no_grad():
if self.model_type == 'raw':
outputs = self.model(**inputs, mode='test')
else:
outputs = self.model.run(None, {
self.model.get_inputs()[0].name:
onnx_to_numpy(inputs['img'])
})[0]
outputs = dict(prob=torch.from_numpy(outputs))
return outputs


try:
from easy_vision.python.inference.predictor import PredictorInterface
Expand Down
4 changes: 2 additions & 2 deletions easycv/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# GENERATED VERSION FILE
# TIME: Thu Nov 5 14:17:50 2020

__version__ = '0.11.6'
short_version = '0.11.6'
__version__ = '0.11.7'
short_version = '0.11.7'
22 changes: 22 additions & 0 deletions tests/test_apis/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_export_cls_syncbn(self):
cfg = mmcv_config_fromfile(config_file)
cfg_options = {
'model.backbone.norm_cfg.type': 'SyncBN',
'export.export_type': 'raw'
}
if cfg_options is not None:
cfg.merge_from_dict(cfg_options)
Expand Down Expand Up @@ -210,6 +211,27 @@ def test_export_stgcn_jit(self):

self.assertTrue(os.path.exists(filename + '.jit'))

def test_export_resnet_onnx(self):

ckpt_path = PRETRAINED_MODEL_RESNET50

easycv_dir = os.path.dirname(easycv.__file__)

if os.path.exists(os.path.join(easycv_dir, 'configs')):
config_dir = os.path.join(easycv_dir, 'configs')
else:
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
config_file = os.path.join(
config_dir,
'classification/imagenet/resnet/imagenet_resnet50_jpg.py')

with tempfile.TemporaryDirectory() as tmpdir:
cfg = mmcv_config_fromfile(config_file)
cfg.export.export_type = 'onnx'
filename = os.path.join(tmpdir, 'imagenet_resnet50')
export(cfg, ckpt_path, filename)
self.assertTrue(os.path.exists(filename + '.onnx'))


if __name__ == '__main__':
unittest.main()
14 changes: 13 additions & 1 deletion tests/test_predictors/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from easycv.predictors.classifier import ClassificationPredictor
from easycv.utils.test_util import clean_up, get_tmp_dir
from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD,
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR)
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR,
PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD)


class ClassificationPredictorTest(unittest.TestCase):
Expand All @@ -33,6 +34,17 @@ def test_single(self):
self.assertListEqual(results['class_name'], ['"Persian cat",'])
self.assertEqual(len(results['class_probs']), 1000)

def test_onnx_single(self):
checkpoint = PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD
predict_op = ClassificationPredictor(model_path=checkpoint)

img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')

results = predict_op([img_path])[0]
self.assertListEqual(results['class'], [578])
self.assertListEqual(results['class_name'], ['gown'])
self.assertEqual(len(results['class_probs']), 1000)

def test_batch(self):
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
Expand Down
16 changes: 8 additions & 8 deletions tests/test_predictors/test_pose_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def _base_test(self, predictor):

assert_array_almost_equal(
result0['bbox'],
np.array([[352.3085, 59.00325, 691.4247, 511.15814, 1.],
[10.511196, 177.74883, 101.824326, 299.49966, 1.],
[224.82036, 114.439865, 312.51306, 231.36348, 1.],
[200.71407, 114.716736, 337.17535, 296.6651, 1.]],
np.array([[438.9, 59., 604.8, 511.2, 0.9],
[10.5, 179.6, 101.8, 297.7, 0.9],
[229.6, 114.4, 307.8, 231.4, 0.6],
[229.4, 114.7, 308.5, 296.7, 0.6]],
dtype=np.float32),
decimal=1)
vis_result = predictor.show_result(img1, result0)
Expand Down Expand Up @@ -92,10 +92,10 @@ def _base_test(self, predictor):

assert_array_almost_equal(
result1['bbox'][:4],
np.array([[436.23096, 214.72766, 584.26013, 412.09985, 1.],
[43.990044, 91.04126, 164.28406, 251.43329, 1.],
[127.44148, 100.38604, 254.219, 269.42273, 1.],
[190.08075, 117.31801, 311.22394, 278.8423, 1.]],
np.array([[470.6, 214.7, 549.9, 412.1, 0.9],
[71.6, 91., 136.7, 251.4, 0.9],
[159.7, 100.4, 221.9, 269.4, 0.9],
[219.4, 117.3, 281.9, 278.8, 0.9]],
dtype=np.float32),
decimal=1)
vis_result = predictor.show_result(img2, result1)
Expand Down
3 changes: 3 additions & 0 deletions tests/ut_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@
PRETRAINED_MODEL_RESNET50_WITHOUTHEAD = os.path.join(
BASE_LOCAL_PATH,
'pretrained_models/classification/resnet/resnet50_withhead.pth')
PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD = os.path.join(
BASE_LOCAL_PATH,
'pretrained_models/classification/resnet/imagenet_resnet50.onnx')
PRETRAINED_MODEL_FACEID = os.path.join(BASE_LOCAL_PATH,
'pretrained_models/faceid')
PRETRAINED_MODEL_YOLOXS_EXPORT = os.path.join(
Expand Down

0 comments on commit 8c90cea

Please sign in to comment.