diff --git a/yolort/runtime/ort_helper.py b/yolort/runtime/ort_helper.py
index 5ba21cc0..042f5a2b 100644
--- a/yolort/runtime/ort_helper.py
+++ b/yolort/runtime/ort_helper.py
@@ -20,6 +20,7 @@ def export_onnx(
     skip_preprocess: bool = False,
     opset_version: int = 11,
     batch_size: int = 1,
+    half: bool = False,
 ) -> None:
     """
     Export to ONNX models that can be used for ONNX Runtime inferencing.
@@ -55,6 +56,7 @@ def export_onnx(
         skip_preprocess=skip_preprocess,
         opset_version=opset_version,
         batch_size=batch_size,
+        half=half,
     )
 
     onnx_builder.to_onnx(onnx_path)
@@ -94,6 +96,7 @@ def __init__(
         skip_preprocess: bool = False,
         opset_version: int = 11,
         batch_size: int = 1,
+        half: bool = False,
     ) -> None:
         super().__init__()
 
@@ -111,6 +114,7 @@ def __init__(
         if model is None:
             model = self._build_model()
         self.model = model
+        self.half = half
 
         # For exporting ONNX model
         self._opset_version = opset_version
@@ -215,3 +219,15 @@ def to_onnx(self, onnx_path: str, **kwargs):
             dynamic_axes=self.dynamic_axes,
             **kwargs,
         )
+        if self.half:
+            # Convert the exported FP32 graph to FP16 in place.
+            try:
+                import onnxmltools
+                from onnxmltools.utils.float16_converter import convert_float_to_float16
+            except ImportError as e:
+                print(f"onnxmltools is required for FP16 export; skipping FP16 conversion. Error: {e}")
+            else:
+                model_fp32 = onnxmltools.utils.load_model(onnx_path)
+                model_fp16 = convert_float_to_float16(model_fp32)
+                onnxmltools.utils.save_model(model_fp16, onnx_path)
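
For reference, a minimal usage sketch of the new flag. Only `opset_version`, `batch_size`, and `half` are confirmed by the diff above; the `checkpoint_path` parameter and the file names are assumptions for illustration:

```python
from yolort.runtime.ort_helper import export_onnx

# Export a YOLOv5 checkpoint to ONNX, then rewrite the graph as float16.
# NOTE: checkpoint_path and both file names are illustrative assumptions;
# half=True is the flag introduced by this patch.
export_onnx(
    "yolov5s.onnx",
    checkpoint_path="yolov5s.pt",
    opset_version=11,
    batch_size=1,
    half=True,
)
```

Since the FP16 conversion is done via onnxmltools, that package must be installed (e.g. `pip install onnxmltools`) for `half=True` to take effect; otherwise the builder prints a warning and leaves the FP32 model untouched.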