Skip to content

Commit

Permalink
Trt10 yolov8 (#1557)
Browse files Browse the repository at this point in the history
* The v5-cls model supports TensorRT10

* The v5-cls model supports TensorRT10 Python API

* add YOLOv5-cls readme

* pre-commit and modify trtx download branch

* pre-commit

* The v5 det model supports TensorRT10

* import pycuda.autoinit  # noqa: F401

* The v5 det model supports TensorRT10 Python API

* modeify readme

* modefiy readme

* modify reamde

* Delete the link of nvinfer_plugin

* Add TensorRT10 support for YOLOv8

* pre-commit

* delelet images,  modify add test images in readme
  • Loading branch information
mpj1234 authored Jul 29, 2024
1 parent c3e5d41 commit 54e0858
Show file tree
Hide file tree
Showing 30 changed files with 7,444 additions and 0 deletions.
57 changes: 57 additions & 0 deletions yolov8/yolov8_trt10/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
cmake_minimum_required(VERSION 3.10)

project(yolov8)

add_definitions(-std=c++11)
add_definitions(-DAPI_EXPORTS)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
enable_language(CUDA)

include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/plugin)

# include and link dirs of cuda and tensorrt, you need adapt them if yours are different
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
message("embed_platform on")
include_directories(/usr/local/cuda/targets/aarch64-linux/include)
link_directories(/usr/local/cuda/targets/aarch64-linux/lib)
else()
message("embed_platform off")

# cuda
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)

# tensorrt
include_directories(/workspace/shared/TensorRT-10.2.0.19/include/)
link_directories(/workspace/shared/TensorRT-10.2.0.19/lib/)

# include_directories(/home/lindsay/TensorRT-7.2.3.4/include)
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib)
endif()

add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/plugin/yololayer.cu)
target_link_libraries(myplugins nvinfer cudart)

find_package(OpenCV)
include_directories(${OpenCV_INCLUDE_DIRS})

file(GLOB_RECURSE SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu)
add_executable(yolov8_det ${PROJECT_SOURCE_DIR}/yolov8_det.cpp ${SRCS})

target_link_libraries(yolov8_det nvinfer)
target_link_libraries(yolov8_det cudart)
target_link_libraries(yolov8_det myplugins)
target_link_libraries(yolov8_det ${OpenCV_LIBS})

add_executable(yolov8_seg ${PROJECT_SOURCE_DIR}/yolov8_seg.cpp ${SRCS})
target_link_libraries(yolov8_seg nvinfer cudart myplugins ${OpenCV_LIBS})

add_executable(yolov8_pose ${PROJECT_SOURCE_DIR}/yolov8_pose.cpp ${SRCS})
target_link_libraries(yolov8_pose nvinfer cudart myplugins ${OpenCV_LIBS})

add_executable(yolov8_cls ${PROJECT_SOURCE_DIR}/yolov8_cls.cpp ${SRCS})
target_link_libraries(yolov8_cls nvinfer cudart myplugins ${OpenCV_LIBS})
151 changes: 151 additions & 0 deletions yolov8/yolov8_trt10/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
## Introduce

Yolov8 model supports TensorRT-10.

## Environment

CUDA: 11.8
CUDNN: 8.9.1.23
TensorRT: TensorRT-10.2.0.19

## Support

* [x] YOLOv8-cls support FP32/FP16/INT8 and Python/C++ API
* [x] YOLOv8-det support FP32/FP16/INT8 and Python/C++ API
* [x] YOLOv8-seg support FP32/FP16/INT8 and Python/C++ API
* [x] YOLOv8-pose support FP32/FP16/INT8 and Python/C++ API

## Config

* Choose the YOLOv8 sub-model n/s/m/l/x/n6/s6/m6/l6/x6 from command line arguments.
* Other configs please check [src/config.h](src/config.h)

## Build and Run

1. generate .wts from pytorch with .pt, or download .wts from model zoo

```shell
git clone https://gitclone.com/github.com/ultralytics/ultralytics.git
git clone -b trt10 https://github.com/wang-xinyu/tensorrtx.git
cd yolov8/
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-cls.pt
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-seg.pt
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-pose.pt
cp [PATH-TO-TENSORRTX]/yolov8/gen_wts.py .
python gen_wts.py -w yolov8n-cls.pt -o yolov8n-cls.wts -t cls
python gen_wts.py -w yolov8n.pt -o yolov8n.wts
python gen_wts.py -w yolov8n-seg.pt -o yolov8n-seg.wts -t seg
python gen_wts.py -w yolov8n-pose.pt -o yolov8n-pose.wts -t pose
# A file 'yolov8n.wts' will be generated.
```

2. build tensorrtx/yolov8/yolov8_trt10 and run

#### Classification

```shell
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10

# add test images
mkdir images
cp [PATH-TO-TENSORRTX]/yolov3-spp/samples/*.jpg ./images

# Update kNumClass in src/config.h if your model is trained on custom dataset
mkdir build
cd build
cp [PATH-TO-ultralytics-yolov8]/yolov8sn-cls.wts .
cmake ..
make

# Download ImageNet labels
wget https://github.com/joannzhang00/ImageNet-dataset-classes-labels/blob/main/imagenet_classes.txt

# Build and serialize TensorRT engine
./yolov8_cls -s yolov8n-cls.wts yolov8n-cls.engine [n/s/m/l/x]

# Run inference
./yolov8_cls -d yolov8n-cls.engine ../images
# The results are displayed in the console
```

3. Optional, load and run the tensorrt model in Python
```shell
// Install python-tensorrt, pycuda, etc.
// Ensure the yolov8n-cls.engine
python yolov8_cls_trt.py ./build/yolov8n-cls.engine ../images
# faq: in windows bug pycuda._driver.LogicError
# faq: in linux bug Segmentation fault
# Add the following code to the py file:
# import pycuda.autoinit
# import pycuda.driver as cuda
```

#### Detection

```shell
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10
# Update kNumClass in src/config.h if your model is trained on custom dataset
mkdir build
cd build
cp [PATH-TO-ultralytics-yolov8]/yolov8n.wts .
cmake ..
make

# Build and serialize TensorRT engine
./yolov8_det -s yolov8n.wts yolov8n.engine [n/s/m/l/x]

# Run inference
./yolov8_det -d yolov8n.engine ../images [c/g]
# The results are displayed in the console
```

#### Segmentation

```shell
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10
# Update kNumClass in src/config.h if your model is trained on custom dataset
mkdir build
cd build
cp [PATH-TO-ultralytics-yolov8]/yolov8n-seg.wts .
cmake ..
make

# Build and serialize TensorRT engine
./yolov8_seg -s yolov8n-seg.wts yolov8n-seg.engine [n/s/m/l/x]

# Download the labels file
wget -O coco.txt https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt

# Run inference
./yolov8_seg -d yolov8n-seg.engine ../images [c/g] coco.txt
# The results are displayed in the console
```

#### Pose

```shell
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10
# Update kNumClass in src/config.h if your model is trained on custom dataset
mkdir build
cd build
cp [PATH-TO-ultralytics-yolov8]/yolov8n-pose.wts .
cmake ..
make

# Build and serialize TensorRT engine
./yolov8_seg -s yolov8n-pose.wts yolov8n-pose.engine [n/s/m/l/x]

# Run inference
./yolov8_seg -d yolov8n-seg.engine ../images c
# The results are displayed in the console
```

## INT8 Quantization
1. Prepare calibration images, you can randomly select 1000s images from your train set. For coco, you can also download my calibration images `coco_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) pwd: a9wh
2. unzip it in yolov8_trt10/build
3. set the macro `USE_INT8` in src/config.h and make again
4. serialize the model and test

## More Information
See the readme in [home page.](https://github.com/wang-xinyu/tensorrtx)
57 changes: 57 additions & 0 deletions yolov8/yolov8_trt10/gen_wts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sys # noqa: F401
import argparse
import os
import struct
import torch


def parse_args():
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
parser.add_argument('-w', '--weights', required=True,
help='Input weights (.pt) file path (required)')
parser.add_argument(
'-o', '--output', help='Output (.wts) file path (optional)')
parser.add_argument(
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'],
help='determines the model is detection/classification')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid input file')
if not args.output:
args.output = os.path.splitext(args.weights)[0] + '.wts'
elif os.path.isdir(args.output):
args.output = os.path.join(
args.output,
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
return args.weights, args.output, args.type


pt_file, wts_file, m_type = parse_args()

print(f'Generating .wts for {m_type} model')

# Load model
print(f'Loading {pt_file}')

# Initialize
device = 'cpu'

# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32

if m_type in ['detect', 'seg', 'pose']:
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]

delattr(model.model[-1], 'anchors')

model.to(device).eval()

with open(wts_file, 'w') as f:
f.write('{}\n'.format(len(model.state_dict().keys())))
for k, v in model.state_dict().items():
vr = v.reshape(-1).cpu().numpy()
f.write('{} {} '.format(k, len(vr)))
for vv in vr:
f.write(' ')
f.write(struct.pack('>f', float(vv)).hex())
f.write('\n')
35 changes: 35 additions & 0 deletions yolov8/yolov8_trt10/include/block.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include <map>
#include <string>
#include <vector>
#include "NvInfer.h"

std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file);

nvinfer1::IScaleLayer* addBatchNorm2d(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input,
std::string lname, float eps);

nvinfer1::IElementWiseLayer* convBnSiLU(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input,
int ch, int k, int s, int p, std::string lname);

nvinfer1::IElementWiseLayer* C2F(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1,
int c2, int n, bool shortcut, float e, std::string lname);

nvinfer1::IElementWiseLayer* C2(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights>& weightMap, nvinfer1::ITensor& input, int c1,
int c2, int n, bool shortcut, float e, std::string lname);

nvinfer1::IElementWiseLayer* SPPF(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1,
int c2, int k, std::string lname);

nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int ch, int grid, int k, int s, int p, std::string lname);

nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network,
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry,
int px_arry_num, bool is_segmentation, bool is_pose);
39 changes: 39 additions & 0 deletions yolov8/yolov8_trt10/include/calibrator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef ENTROPY_CALIBRATOR_H
#define ENTROPY_CALIBRATOR_H

#include <NvInfer.h>
#include <string>
#include <vector>
#include "macros.h"

//! \class Int8EntropyCalibrator2
//!
//! \brief Implements Entropy calibrator 2.
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
//!
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
public:
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name,
const char* input_blob_name, bool read_cache = true);
virtual ~Int8EntropyCalibrator2();
int getBatchSize() const TRT_NOEXCEPT override;
bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override;
const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override;
void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override;

private:
int batchsize_;
int input_w_;
int input_h_;
int img_idx_;
std::string img_dir_;
std::vector<std::string> img_files_;
size_t input_count_;
std::string calib_table_name_;
const char* input_blob_name_;
bool read_cache_;
void* device_input_;
std::vector<char> calib_cache_;
};

#endif // ENTROPY_CALIBRATOR_H
27 changes: 27 additions & 0 deletions yolov8/yolov8_trt10/include/config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// #define USE_FP16
// #define USE_FP32
#define USE_INT8

const static char* kInputTensorName = "images";
const static char* kOutputTensorName = "output";
const static char* kProtoTensorName = "proto";
const static int kNumClass = 80;
const static int kPoseNumClass = 1;
const static int kNumberOfPoints = 17; // number of keypoints total
const static int kBatchSize = 1;
const static int kGpuId = 0;
const static int kInputH = 640;
const static int kInputW = 640;
const static float kNmsThresh = 0.45f;
const static float kConfThresh = 0.5f;
const static float kConfThreshKeypoints = 0.5f; // keypoints confidence
const static int kMaxInputImageSize = 3000 * 3000;
const static int kMaxNumOutputBbox = 1000;
//Quantization input image folder path
const static char* kInputQuantizationFolder = "./coco_calib";

// Classfication model's number of classes
constexpr static int kClsNumClass = 1000;
// Classfication model's input shape
constexpr static int kClsInputH = 224;
constexpr static int kClsInputW = 224;
17 changes: 17 additions & 0 deletions yolov8/yolov8_trt10/include/cuda_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef TRTX_CUDA_UTILS_H_
#define TRTX_CUDA_UTILS_H_

#include <cuda_runtime_api.h>

#ifndef CUDA_CHECK
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#endif // CUDA_CHECK

#endif // TRTX_CUDA_UTILS_H_
Loading

0 comments on commit 54e0858

Please sign in to comment.