Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trt10 yolov8 #1557

Merged
merged 17 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions yolov8/yolov8_trt10/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
cmake_minimum_required(VERSION 3.10)

project(yolov8)

# API_EXPORTS is defined for every target created in this directory.
add_definitions(-DAPI_EXPORTS)

# Request C++11 through CMake's standard variables instead of a raw
# compiler-specific -std flag, for both the host and the CUDA compiler.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 11)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# Default to a Debug build, but do not override a build type chosen on the
# command line (-DCMAKE_BUILD_TYPE=...) or a multi-config generator.
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
  set(CMAKE_BUILD_TYPE Debug)
endif()

# Fall back to the conventional nvcc location only when no CUDA compiler was
# supplied (-DCMAKE_CUDA_COMPILER=...).
if(NOT CMAKE_CUDA_COMPILER)
  set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
endif()
enable_language(CUDA)

include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/plugin)

# TensorRT installation root used on non-aarch64 hosts.
# Override with -DTENSORRT_ROOT=/path/to/TensorRT if yours is different.
set(TENSORRT_ROOT "/workspace/shared/TensorRT-10.2.0.19" CACHE PATH
    "TensorRT installation root (used when not building on aarch64)")

# include and link dirs of cuda and tensorrt, you need adapt them if yours are different
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
  message("embed_platform on")
  include_directories(/usr/local/cuda/targets/aarch64-linux/include)
  link_directories(/usr/local/cuda/targets/aarch64-linux/lib)
else()
  message("embed_platform off")

  # cuda
  include_directories(/usr/local/cuda/include)
  link_directories(/usr/local/cuda/lib64)

  # tensorrt
  include_directories(${TENSORRT_ROOT}/include)
  link_directories(${TENSORRT_ROOT}/lib)
endif()

# Shared library built from the YOLO layer plugin source.
add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/plugin/yololayer.cu)
target_link_libraries(myplugins PUBLIC nvinfer cudart)

# Fail at configure time if OpenCV is missing instead of at compile/link time.
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})

# Sources under src/ are compiled into every executable below.
file(GLOB_RECURSE SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu)

# One executable per task: yolov8_det, yolov8_seg, yolov8_pose, yolov8_cls.
# Each is built from yolov8_<task>.cpp plus the shared sources.
foreach(task det seg pose cls)
  add_executable(yolov8_${task} ${PROJECT_SOURCE_DIR}/yolov8_${task}.cpp ${SRCS})
  target_link_libraries(yolov8_${task} PRIVATE nvinfer cudart myplugins ${OpenCV_LIBS})
endforeach()
151 changes: 151 additions & 0 deletions yolov8/yolov8_trt10/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
## Introduction

Yolov8 model supports TensorRT-10.

## Environment

CUDA: 11.8
CUDNN: 8.9.1.23
TensorRT: TensorRT-10.2.0.19

## Support

* [x] YOLOv8-cls supports FP32/FP16/INT8 and the Python/C++ API
* [x] YOLOv8-det supports FP32/FP16/INT8 and the Python/C++ API
* [x] YOLOv8-seg supports FP32/FP16/INT8 and the Python/C++ API
* [x] YOLOv8-pose supports FP32/FP16/INT8 and the Python/C++ API

## Config

* Choose the YOLOv8 sub-model n/s/m/l/x/n6/s6/m6/l6/x6 from command line arguments.
* For other configs, please check [include/config.h](include/config.h)

## Build and Run

1. generate .wts from pytorch with .pt, or download .wts from model zoo

```shell
git clone https://gitclone.com/github.com/ultralytics/ultralytics.git
git clone -b trt10 https://github.com/wang-xinyu/tensorrtx.git
cd yolov8/
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-cls.pt
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-seg.pt
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-pose.pt
cp [PATH-TO-TENSORRTX]/yolov8/gen_wts.py .
python gen_wts.py -w yolov8n-cls.pt -o yolov8n-cls.wts -t cls
python gen_wts.py -w yolov8n.pt -o yolov8n.wts
python gen_wts.py -w yolov8n-seg.pt -o yolov8n-seg.wts -t seg
python gen_wts.py -w yolov8n-pose.pt -o yolov8n-pose.wts -t pose
# A file 'yolov8n.wts' will be generated.
```

2. build tensorrtx/yolov8/yolov8_trt10 and run

#### Classification

```shell
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10

# add test images
mkdir images
cp [PATH-TO-TENSORRTX]/yolov3-spp/samples/*.jpg ./images

# Update kClsNumClass in include/config.h if your model is trained on a custom dataset
mkdir build
cd build
cp [PATH-TO-ultralytics-yolov8]/yolov8n-cls.wts .
cmake ..
make

# Download ImageNet labels
wget https://raw.githubusercontent.com/joannzhang00/ImageNet-dataset-classes-labels/main/imagenet_classes.txt

# Build and serialize TensorRT engine
./yolov8_cls -s yolov8n-cls.wts yolov8n-cls.engine [n/s/m/l/x]

# Run inference
./yolov8_cls -d yolov8n-cls.engine ../images
# The results are displayed in the console
```

3. Optional, load and run the tensorrt model in Python
```shell
# Install python-tensorrt, pycuda, etc.
# Make sure yolov8n-cls.engine has been generated
python yolov8_cls_trt.py ./build/yolov8n-cls.engine ../images
# FAQ: on Windows this may fail with pycuda._driver.LogicError,
# and on Linux with a segmentation fault.
# If so, add the following imports to the .py file:
# import pycuda.autoinit
# import pycuda.driver as cuda
```

#### Detection

```shell
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10
# Update kNumClass in include/config.h if your model is trained on a custom dataset
mkdir build
cd build
cp [PATH-TO-ultralytics-yolov8]/yolov8n.wts .
cmake ..
make

# Build and serialize TensorRT engine
./yolov8_det -s yolov8n.wts yolov8n.engine [n/s/m/l/x]

# Run inference
./yolov8_det -d yolov8n.engine ../images [c/g]
# The results are displayed in the console
```

#### Segmentation

```shell
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10
# Update kNumClass in include/config.h if your model is trained on a custom dataset
mkdir build
cd build
cp [PATH-TO-ultralytics-yolov8]/yolov8n-seg.wts .
cmake ..
make

# Build and serialize TensorRT engine
./yolov8_seg -s yolov8n-seg.wts yolov8n-seg.engine [n/s/m/l/x]

# Download the labels file
wget -O coco.txt https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt

# Run inference
./yolov8_seg -d yolov8n-seg.engine ../images [c/g] coco.txt
# The results are displayed in the console
```

#### Pose

```shell
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10
# Update kNumClass in include/config.h if your model is trained on a custom dataset
mkdir build
cd build
cp [PATH-TO-ultralytics-yolov8]/yolov8n-pose.wts .
cmake ..
make

# Build and serialize TensorRT engine
./yolov8_pose -s yolov8n-pose.wts yolov8n-pose.engine [n/s/m/l/x]

# Run inference
./yolov8_pose -d yolov8n-pose.engine ../images c
# The results are displayed in the console
```

## INT8 Quantization
1. Prepare calibration images: you can randomly select about 1000 images from your training set. For COCO, you can also download my calibration images `coco_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) pwd: a9wh
2. unzip it in yolov8_trt10/build
3. set the macro `USE_INT8` in include/config.h and make again
4. serialize the model and test

## More Information
See the readme in [home page.](https://github.com/wang-xinyu/tensorrtx)
57 changes: 57 additions & 0 deletions yolov8/yolov8_trt10/gen_wts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sys # noqa: F401
import argparse
import os
import struct
import torch


def parse_args():
    """Parse the command line and resolve the input/output paths.

    Returns:
        A ``(weights, output, type)`` tuple: the input ``.pt`` path, the
        resolved ``.wts`` output path and the model type string.

    Raises:
        SystemExit: if the ``--weights`` path is not an existing file (and,
            via argparse, on invalid or missing arguments).
    """
    cli = argparse.ArgumentParser(description='Convert .pt file to .wts')
    cli.add_argument(
        '-w', '--weights', required=True,
        help='Input weights (.pt) file path (required)')
    cli.add_argument(
        '-o', '--output',
        help='Output (.wts) file path (optional)')
    cli.add_argument(
        '-t', '--type', type=str, default='detect',
        choices=['detect', 'cls', 'seg', 'pose'],
        help='determines the model is detection/classification')
    opts = cli.parse_args()

    if not os.path.isfile(opts.weights):
        raise SystemExit('Invalid input file')

    # Input path with its extension removed, e.g. "dir/model.pt" -> "dir/model".
    stem = os.path.splitext(opts.weights)[0]
    out_path = opts.output
    if not out_path:
        # No output given: write next to the input file.
        out_path = stem + '.wts'
    elif os.path.isdir(out_path):
        # Output is a directory: place "<input name>.wts" inside it.
        out_path = os.path.join(out_path, os.path.basename(stem) + '.wts')

    return opts.weights, out_path, opts.type


# Script body: load the checkpoint and dump every state-dict tensor to a
# text .wts file.
pt_file, wts_file, m_type = parse_args()

print(f'Generating .wts for {m_type} model')

# The conversion runs entirely on the CPU.
device = 'cpu'

# Load model
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code; only run this script on checkpoints from a trusted source.
print(f'Loading {pt_file}')
model = torch.load(pt_file, map_location=device)['model'].float()  # load to FP32

if m_type in ['detect', 'seg', 'pose']:
    # Remove the `anchors` attribute from the last module before exporting.
    delattr(model.model[-1], 'anchors')

model.to(device).eval()

# File layout:
#   line 1:      number of state-dict entries
#   other lines: "<key> <count> " followed by " <hex>" for every value, where
#                <hex> is the big-endian IEEE-754 float32 encoding.
state = model.state_dict()
with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(state)))
    for k, v in state.items():
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        f.write(''.join(' ' + struct.pack('>f', float(vv)).hex() for vv in vr))
        f.write('\n')
35 changes: 35 additions & 0 deletions yolov8/yolov8_trt10/include/block.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include <map>
#include <string>
#include <vector>
#include "NvInfer.h"

// Declarations of the network-building helpers. Only the prototypes are
// visible in this header; the notes below describe the signatures and hedge
// anything that depends on the definitions.

// Returns a map from name to nvinfer1::Weights.
// NOTE(review): `file` is presumably the path of a .wts weights file - confirm.
std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file);

// Returns an IScaleLayer created on `network` for `input`.
// `lname` is presumably the weight-name prefix looked up in `weightMap` and
// `eps` the batch-norm epsilon - confirm against the definition.
// NOTE(review): `weightMap` is taken by value, so the map is copied per call.
nvinfer1::IScaleLayer* addBatchNorm2d(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input,
std::string lname, float eps);

// Returns an IElementWiseLayer built on `network` from `input`.
// NOTE(review): judging by the name, ch/k/s/p look like output channels,
// kernel size, stride and padding of a conv + batch-norm + SiLU block - confirm.
nvinfer1::IElementWiseLayer* convBnSiLU(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input,
int ch, int k, int s, int p, std::string lname);

// Returns an IElementWiseLayer built on `network` from `input`.
// NOTE(review): c1/c2/n/shortcut/e presumably mirror the arguments of the
// YOLOv8 "C2f" module - confirm against the definition.
nvinfer1::IElementWiseLayer* C2F(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1,
int c2, int n, bool shortcut, float e, std::string lname);

// Same parameter list as C2F, except that `weightMap` is taken by non-const
// reference here while the other helpers in this header take it by value.
nvinfer1::IElementWiseLayer* C2(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights>& weightMap, nvinfer1::ITensor& input, int c1,
int c2, int n, bool shortcut, float e, std::string lname);

// Returns an IElementWiseLayer built on `network` from `input`.
// NOTE(review): presumably the "SPPF" block with pooling kernel size `k` - confirm.
nvinfer1::IElementWiseLayer* SPPF(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1,
int c2, int k, std::string lname);

// Returns an IShuffleLayer built on `network` from `input`.
// NOTE(review): the meaning of ch/grid/k/s/p is not visible here - confirm.
nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int ch, int grid, int k, int s, int p, std::string lname);

// Returns an IPluginV2Layer added to `network`, taking the concatenation
// layers in `dets` plus an int array (`px_arry`, `px_arry_num` elements).
// `is_segmentation` / `is_pose` select the task variant; how they are used is
// not visible in this header.
nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network,
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry,
int px_arry_num, bool is_segmentation, bool is_pose);
39 changes: 39 additions & 0 deletions yolov8/yolov8_trt10/include/calibrator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef ENTROPY_CALIBRATOR_H
#define ENTROPY_CALIBRATOR_H

#include <NvInfer.h>
#include <string>
#include <vector>
#include "macros.h"

//! \class Int8EntropyCalibrator2
//!
//! \brief Implements Entropy calibrator 2.
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
//!
//! Derives from nvinfer1::IInt8EntropyCalibrator2 and overrides its four
//! calibration callbacks. Only the declaration is visible in this header, so
//! notes about behaviour below are assumptions to confirm against the
//! implementation file.
//!
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
public:
//! Constructs the calibrator from a batch size, input width/height, an image
//! directory, a calibration-table name and the input blob name.
//! `read_cache` defaults to true.
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name,
const char* input_blob_name, bool read_cache = true);
virtual ~Int8EntropyCalibrator2();
//! Overrides of the nvinfer1::IInt8EntropyCalibrator2 interface.
//! TRT_NOEXCEPT is not defined in this header (presumably it comes from "macros.h").
int getBatchSize() const TRT_NOEXCEPT override;
bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override;
const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override;
void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override;

private:
int batchsize_;
int input_w_;
int input_h_;
// presumably the index of the next image in img_files_ to feed - confirm
int img_idx_;
std::string img_dir_;
std::vector<std::string> img_files_;
size_t input_count_;
std::string calib_table_name_;
// NOTE(review): stored as a raw pointer, unlike img_dir_ / calib_table_name_
// which are std::string copies; the caller's string must outlive this object
// if the pointer is kept as passed - confirm in the constructor.
const char* input_blob_name_;
bool read_cache_;
// presumably a device buffer holding one input batch - confirm allocation and
// release in the implementation
void* device_input_;
std::vector<char> calib_cache_;
};

#endif // ENTROPY_CALIBRATOR_H
27 changes: 27 additions & 0 deletions yolov8/yolov8_trt10/include/config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef TRTX_YOLOV8_CONFIG_H_
#define TRTX_YOLOV8_CONFIG_H_

// Compile-time configuration shared by the YOLOv8 TensorRT programs.
// The include guard allows this header to be included more than once in a
// translation unit without redefining the constants below.

// Precision selection: USE_INT8 is the macro currently enabled.
// NOTE(review): the README describes INT8 as an opt-in step ("set the macro
// USE_INT8 ... and make again"), yet it is the active default here; confirm
// that this default is intended.
// #define USE_FP16
// #define USE_FP32
#define USE_INT8

// Tensor names
const static char* kInputTensorName = "images";
const static char* kOutputTensorName = "output";
const static char* kProtoTensorName = "proto";

// Class / keypoint counts
const static int kNumClass = 80;
const static int kPoseNumClass = 1;
const static int kNumberOfPoints = 17;  // number of keypoints total

const static int kBatchSize = 1;
const static int kGpuId = 0;

// Input height and width
const static int kInputH = 640;
const static int kInputW = 640;

// Thresholds
const static float kNmsThresh = 0.45f;
const static float kConfThresh = 0.5f;
const static float kConfThreshKeypoints = 0.5f;  // keypoints confidence

// Upper bounds: 3000 * 3000 for the input image size, 1000 output boxes
const static int kMaxInputImageSize = 3000 * 3000;
const static int kMaxNumOutputBbox = 1000;

// Quantization input image folder path
const static char* kInputQuantizationFolder = "./coco_calib";

// Classification model's number of classes
constexpr static int kClsNumClass = 1000;
// Classification model's input shape
constexpr static int kClsInputH = 224;
constexpr static int kClsInputW = 224;

#endif  // TRTX_YOLOV8_CONFIG_H_
17 changes: 17 additions & 0 deletions yolov8/yolov8_trt10/include/cuda_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef TRTX_CUDA_UTILS_H_
#define TRTX_CUDA_UTILS_H_

#include <cuda_runtime_api.h>

// Needed by CUDA_CHECK below (assert and std::cerr), so that this header does
// not depend on what the including file happened to include first.
#include <cassert>
#include <iostream>

// CUDA_CHECK(call): evaluates a CUDA runtime call that yields a cudaError_t.
// If the result is not cudaSuccess, prints the numeric code, the runtime's
// error string and the source location to std::cerr, then hits assert(0).
//
// Notes:
//  * assert(0) is compiled out when NDEBUG is defined, so in such builds a
//    failing call is reported on stderr but execution continues.
//  * The macro expands to a brace block, not do { } while (0); avoid using it
//    as the unbraced body of an if that has an else branch.
#ifndef CUDA_CHECK
#define CUDA_CHECK(callstr)                                                              \
    {                                                                                    \
        cudaError_t error_code = (callstr);                                              \
        if (error_code != cudaSuccess) {                                                 \
            std::cerr << "CUDA error " << error_code << " ("                             \
                      << cudaGetErrorString(error_code) << ") at " << __FILE__ << ":"    \
                      << __LINE__ << std::endl;                                          \
            assert(0);                                                                   \
        }                                                                                \
    }
#endif  // CUDA_CHECK

#endif  // TRTX_CUDA_UTILS_H_
Loading
Loading