From 920c721ea77315291f987ea5b911be7fb619d134 Mon Sep 17 00:00:00 2001
From: Nick Martin <284356+n1mmy@users.noreply.github.com>
Date: Sun, 5 May 2024 02:55:12 -0700
Subject: [PATCH] Backport compatibility with TensorRT version 10 from yolov8
 (#12984)

Add compatibility with TensorRT version 10. Based on the is_trt10 code in yolov8.
---
 export.py        | 14 +++++++++-----
 models/common.py | 40 ++++++++++++++++++++++++++++------------
 2 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/export.py b/export.py
index c660ce660aa5..214d903c2998 100644
--- a/export.py
+++ b/export.py
@@ -346,6 +346,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
     onnx = file.with_suffix(".onnx")
 
     LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
+    is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # is TensorRT >= 10
     assert onnx.exists(), f"failed to export ONNX file: {onnx}"
     f = file.with_suffix(".engine")  # TensorRT engine file
     logger = trt.Logger(trt.Logger.INFO)
@@ -354,9 +355,10 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
 
     builder = trt.Builder(logger)
     config = builder.create_builder_config()
-    config.max_workspace_size = workspace * 1 << 30
-    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice
-
+    if is_trt10:
+        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
+    else:  # TensorRT versions 7, 8
+        config.max_workspace_size = workspace * 1 << 30
     flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
     network = builder.create_network(flag)
     parser = trt.OnnxParser(network, logger)
@@ -381,8 +383,10 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
     LOGGER.info(f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}")
     if builder.platform_has_fast_fp16 and half:
         config.set_flag(trt.BuilderFlag.FP16)
-    with builder.build_engine(network, config) as engine, open(f, "wb") as t:
-        t.write(engine.serialize())
+
+    build = builder.build_serialized_network if is_trt10 else builder.build_engine
+    with build(network, config) as engine, open(f, "wb") as t:
+        t.write(engine if is_trt10 else engine.serialize())
     return f, None
 
 
diff --git a/models/common.py b/models/common.py
index 8925897099c1..12244fd4b3cf 100644
--- a/models/common.py
+++ b/models/common.py
@@ -527,18 +527,34 @@ def __init__(self, weights="yolov5s.pt", device=torch.device("cpu"), dnn=False,
             output_names = []
             fp16 = False  # default updated below
             dynamic = False
-            for i in range(model.num_bindings):
-                name = model.get_binding_name(i)
-                dtype = trt.nptype(model.get_binding_dtype(i))
-                if model.binding_is_input(i):
-                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic
-                        dynamic = True
-                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
-                    if dtype == np.float16:
-                        fp16 = True
-                else:  # output
-                    output_names.append(name)
-                shape = tuple(context.get_binding_shape(i))
+            is_trt10 = not hasattr(model, "num_bindings")
+            num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
+            for i in num:
+                if is_trt10:
+                    name = model.get_tensor_name(i)
+                    dtype = trt.nptype(model.get_tensor_dtype(name))
+                    is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
+                    if is_input:
+                        if -1 in tuple(model.get_tensor_shape(name)):  # dynamic
+                            dynamic = True
+                            context.set_input_shape(name, tuple(model.get_profile_shape(name, 0)[2]))
+                        if dtype == np.float16:
+                            fp16 = True
+                    else:  # output
+                        output_names.append(name)
+                    shape = tuple(context.get_tensor_shape(name))
+                else:
+                    name = model.get_binding_name(i)
+                    dtype = trt.nptype(model.get_binding_dtype(i))
+                    if model.binding_is_input(i):
+                        if -1 in tuple(model.get_binding_shape(i)):  # dynamic
+                            dynamic = True
+                            context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
+                        if dtype == np.float16:
+                            fp16 = True
+                    else:  # output
+                        output_names.append(name)
+                    shape = tuple(context.get_binding_shape(i))
             im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
             bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
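
Note: both hunks hinge on the same major-version switch, detected from trt.__version__ on the
export side and from the absence of the legacy num_bindings attribute on the inference side.
Below is a minimal standalone sketch of the export-side path, assuming builder, network, and
config were already created from a parsed ONNX graph as in export_engine() above; the helper
name build_engine_file is illustrative and not part of the patch.

    import tensorrt as trt  # the installed TensorRT version decides which branch runs


    def build_engine_file(builder, network, config, path, workspace=4):
        # Sketch of the TensorRT 10 / pre-10 build switch used in the patch.
        is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # major version >= 10

        if is_trt10:
            # max_workspace_size was removed in TensorRT 10; use the memory-pool API instead
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
        else:  # TensorRT 7, 8
            config.max_workspace_size = workspace * 1 << 30

        # build_serialized_network() already returns a serialized engine buffer, while the
        # older build_engine() returns an ICudaEngine that still needs serialize() before
        # it can be written to disk.
        build = builder.build_serialized_network if is_trt10 else builder.build_engine
        with build(network, config) as engine, open(path, "wb") as t:
            t.write(engine if is_trt10 else engine.serialize())
        return path

The same is_trt10 flag drives the runtime side in models/common.py, where the TensorRT 10
named-tensor API (get_tensor_name, get_tensor_mode, set_input_shape) replaces the removed
binding API (get_binding_name, binding_is_input, set_binding_shape).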