-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* The v5-cls model supports TensorRT10 * The v5-cls model supports TensorRT10 Python API * add YOLOv5-cls readme * pre-commit and modify trtx download branch * pre-commit * The v5 det model supports TensorRT10 * import pycuda.autoinit # noqa: F401 * The v5 det model supports TensorRT10 Python API * modeify readme * modefiy readme * modify reamde * Delete the link of nvinfer_plugin * Add TensorRT10 support for YOLOv8 * pre-commit * delelet images, modify add test images in readme
- Loading branch information
Showing
30 changed files
with
7,444 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
cmake_minimum_required(VERSION 3.10) | ||
|
||
project(yolov8) | ||
|
||
add_definitions(-std=c++11) | ||
add_definitions(-DAPI_EXPORTS) | ||
set(CMAKE_CXX_STANDARD 11) | ||
set(CMAKE_BUILD_TYPE Debug) | ||
|
||
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) | ||
enable_language(CUDA) | ||
|
||
include_directories(${PROJECT_SOURCE_DIR}/include) | ||
include_directories(${PROJECT_SOURCE_DIR}/plugin) | ||
|
||
# include and link dirs of cuda and tensorrt, you need adapt them if yours are different | ||
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") | ||
message("embed_platform on") | ||
include_directories(/usr/local/cuda/targets/aarch64-linux/include) | ||
link_directories(/usr/local/cuda/targets/aarch64-linux/lib) | ||
else() | ||
message("embed_platform off") | ||
|
||
# cuda | ||
include_directories(/usr/local/cuda/include) | ||
link_directories(/usr/local/cuda/lib64) | ||
|
||
# tensorrt | ||
include_directories(/workspace/shared/TensorRT-10.2.0.19/include/) | ||
link_directories(/workspace/shared/TensorRT-10.2.0.19/lib/) | ||
|
||
# include_directories(/home/lindsay/TensorRT-7.2.3.4/include) | ||
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib) | ||
endif() | ||
|
||
add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/plugin/yololayer.cu) | ||
target_link_libraries(myplugins nvinfer cudart) | ||
|
||
find_package(OpenCV) | ||
include_directories(${OpenCV_INCLUDE_DIRS}) | ||
|
||
file(GLOB_RECURSE SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu) | ||
add_executable(yolov8_det ${PROJECT_SOURCE_DIR}/yolov8_det.cpp ${SRCS}) | ||
|
||
target_link_libraries(yolov8_det nvinfer) | ||
target_link_libraries(yolov8_det cudart) | ||
target_link_libraries(yolov8_det myplugins) | ||
target_link_libraries(yolov8_det ${OpenCV_LIBS}) | ||
|
||
add_executable(yolov8_seg ${PROJECT_SOURCE_DIR}/yolov8_seg.cpp ${SRCS}) | ||
target_link_libraries(yolov8_seg nvinfer cudart myplugins ${OpenCV_LIBS}) | ||
|
||
add_executable(yolov8_pose ${PROJECT_SOURCE_DIR}/yolov8_pose.cpp ${SRCS}) | ||
target_link_libraries(yolov8_pose nvinfer cudart myplugins ${OpenCV_LIBS}) | ||
|
||
add_executable(yolov8_cls ${PROJECT_SOURCE_DIR}/yolov8_cls.cpp ${SRCS}) | ||
target_link_libraries(yolov8_cls nvinfer cudart myplugins ${OpenCV_LIBS}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
## Introduce | ||
|
||
Yolov8 model supports TensorRT-10. | ||
|
||
## Environment | ||
|
||
CUDA: 11.8 | ||
CUDNN: 8.9.1.23 | ||
TensorRT: TensorRT-10.2.0.19 | ||
|
||
## Support | ||
|
||
* [x] YOLOv8-cls support FP32/FP16/INT8 and Python/C++ API | ||
* [x] YOLOv8-det support FP32/FP16/INT8 and Python/C++ API | ||
* [x] YOLOv8-seg support FP32/FP16/INT8 and Python/C++ API | ||
* [x] YOLOv8-pose support FP32/FP16/INT8 and Python/C++ API | ||
|
||
## Config | ||
|
||
* Choose the YOLOv8 sub-model n/s/m/l/x/n6/s6/m6/l6/x6 from command line arguments. | ||
* Other configs please check [src/config.h](src/config.h) | ||
|
||
## Build and Run | ||
|
||
1. generate .wts from pytorch with .pt, or download .wts from model zoo | ||
|
||
```shell | ||
git clone https://gitclone.com/github.com/ultralytics/ultralytics.git | ||
git clone -b trt10 https://github.com/wang-xinyu/tensorrtx.git | ||
cd yolov8/ | ||
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-cls.pt | ||
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt | ||
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-seg.pt | ||
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-pose.pt | ||
cp [PATH-TO-TENSORRTX]/yolov8/gen_wts.py . | ||
python gen_wts.py -w yolov8n-cls.pt -o yolov8n-cls.wts -t cls | ||
python gen_wts.py -w yolov8n.pt -o yolov8n.wts | ||
python gen_wts.py -w yolov8n-seg.pt -o yolov8n-seg.wts -t seg | ||
python gen_wts.py -w yolov8n-pose.pt -o yolov8n-pose.wts -t pose | ||
# A file 'yolov8n.wts' will be generated. | ||
``` | ||
|
||
2. build tensorrtx/yolov8/yolov8_trt10 and run | ||
|
||
#### Classification | ||
|
||
```shell | ||
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10 | ||
|
||
# add test images | ||
mkdir images | ||
cp [PATH-TO-TENSORRTX]/yolov3-spp/samples/*.jpg ./images | ||
|
||
# Update kNumClass in src/config.h if your model is trained on custom dataset | ||
mkdir build | ||
cd build | ||
cp [PATH-TO-ultralytics-yolov8]/yolov8sn-cls.wts . | ||
cmake .. | ||
make | ||
|
||
# Download ImageNet labels | ||
wget https://github.com/joannzhang00/ImageNet-dataset-classes-labels/blob/main/imagenet_classes.txt | ||
|
||
# Build and serialize TensorRT engine | ||
./yolov8_cls -s yolov8n-cls.wts yolov8n-cls.engine [n/s/m/l/x] | ||
|
||
# Run inference | ||
./yolov8_cls -d yolov8n-cls.engine ../images | ||
# The results are displayed in the console | ||
``` | ||
|
||
3. Optional, load and run the tensorrt model in Python | ||
```shell | ||
// Install python-tensorrt, pycuda, etc. | ||
// Ensure the yolov8n-cls.engine | ||
python yolov8_cls_trt.py ./build/yolov8n-cls.engine ../images | ||
# faq: in windows bug pycuda._driver.LogicError | ||
# faq: in linux bug Segmentation fault | ||
# Add the following code to the py file: | ||
# import pycuda.autoinit | ||
# import pycuda.driver as cuda | ||
``` | ||
|
||
#### Detection | ||
|
||
```shell | ||
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10 | ||
# Update kNumClass in src/config.h if your model is trained on custom dataset | ||
mkdir build | ||
cd build | ||
cp [PATH-TO-ultralytics-yolov8]/yolov8n.wts . | ||
cmake .. | ||
make | ||
|
||
# Build and serialize TensorRT engine | ||
./yolov8_det -s yolov8n.wts yolov8n.engine [n/s/m/l/x] | ||
|
||
# Run inference | ||
./yolov8_det -d yolov8n.engine ../images [c/g] | ||
# The results are displayed in the console | ||
``` | ||
|
||
#### Segmentation | ||
|
||
```shell | ||
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10 | ||
# Update kNumClass in src/config.h if your model is trained on custom dataset | ||
mkdir build | ||
cd build | ||
cp [PATH-TO-ultralytics-yolov8]/yolov8n-seg.wts . | ||
cmake .. | ||
make | ||
|
||
# Build and serialize TensorRT engine | ||
./yolov8_seg -s yolov8n-seg.wts yolov8n-seg.engine [n/s/m/l/x] | ||
|
||
# Download the labels file | ||
wget -O coco.txt https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt | ||
|
||
# Run inference | ||
./yolov8_seg -d yolov8n-seg.engine ../images [c/g] coco.txt | ||
# The results are displayed in the console | ||
``` | ||
|
||
#### Pose | ||
|
||
```shell | ||
cd [PATH-TO-TENSORRTX]/yolov8/yolov8_trt10 | ||
# Update kNumClass in src/config.h if your model is trained on custom dataset | ||
mkdir build | ||
cd build | ||
cp [PATH-TO-ultralytics-yolov8]/yolov8n-pose.wts . | ||
cmake .. | ||
make | ||
|
||
# Build and serialize TensorRT engine | ||
./yolov8_seg -s yolov8n-pose.wts yolov8n-pose.engine [n/s/m/l/x] | ||
|
||
# Run inference | ||
./yolov8_seg -d yolov8n-seg.engine ../images c | ||
# The results are displayed in the console | ||
``` | ||
|
||
## INT8 Quantization | ||
1. Prepare calibration images, you can randomly select 1000s images from your train set. For coco, you can also download my calibration images `coco_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) pwd: a9wh | ||
2. unzip it in yolov8_trt10/build | ||
3. set the macro `USE_INT8` in src/config.h and make again | ||
4. serialize the model and test | ||
|
||
## More Information | ||
See the readme in [home page.](https://github.com/wang-xinyu/tensorrtx) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import sys # noqa: F401 | ||
import argparse | ||
import os | ||
import struct | ||
import torch | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='Convert .pt file to .wts') | ||
parser.add_argument('-w', '--weights', required=True, | ||
help='Input weights (.pt) file path (required)') | ||
parser.add_argument( | ||
'-o', '--output', help='Output (.wts) file path (optional)') | ||
parser.add_argument( | ||
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'], | ||
help='determines the model is detection/classification') | ||
args = parser.parse_args() | ||
if not os.path.isfile(args.weights): | ||
raise SystemExit('Invalid input file') | ||
if not args.output: | ||
args.output = os.path.splitext(args.weights)[0] + '.wts' | ||
elif os.path.isdir(args.output): | ||
args.output = os.path.join( | ||
args.output, | ||
os.path.splitext(os.path.basename(args.weights))[0] + '.wts') | ||
return args.weights, args.output, args.type | ||
|
||
|
||
pt_file, wts_file, m_type = parse_args() | ||
|
||
print(f'Generating .wts for {m_type} model') | ||
|
||
# Load model | ||
print(f'Loading {pt_file}') | ||
|
||
# Initialize | ||
device = 'cpu' | ||
|
||
# Load model | ||
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32 | ||
|
||
if m_type in ['detect', 'seg', 'pose']: | ||
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None] | ||
|
||
delattr(model.model[-1], 'anchors') | ||
|
||
model.to(device).eval() | ||
|
||
with open(wts_file, 'w') as f: | ||
f.write('{}\n'.format(len(model.state_dict().keys()))) | ||
for k, v in model.state_dict().items(): | ||
vr = v.reshape(-1).cpu().numpy() | ||
f.write('{} {} '.format(k, len(vr))) | ||
for vv in vr: | ||
f.write(' ') | ||
f.write(struct.pack('>f', float(vv)).hex()) | ||
f.write('\n') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#pragma once | ||
|
||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
#include "NvInfer.h" | ||
|
||
std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file); | ||
|
||
nvinfer1::IScaleLayer* addBatchNorm2d(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, | ||
std::string lname, float eps); | ||
|
||
nvinfer1::IElementWiseLayer* convBnSiLU(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, | ||
int ch, int k, int s, int p, std::string lname); | ||
|
||
nvinfer1::IElementWiseLayer* C2F(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1, | ||
int c2, int n, bool shortcut, float e, std::string lname); | ||
|
||
nvinfer1::IElementWiseLayer* C2(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights>& weightMap, nvinfer1::ITensor& input, int c1, | ||
int c2, int n, bool shortcut, float e, std::string lname); | ||
|
||
nvinfer1::IElementWiseLayer* SPPF(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1, | ||
int c2, int k, std::string lname); | ||
|
||
nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap, | ||
nvinfer1::ITensor& input, int ch, int grid, int k, int s, int p, std::string lname); | ||
|
||
nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network, | ||
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry, | ||
int px_arry_num, bool is_segmentation, bool is_pose); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#ifndef ENTROPY_CALIBRATOR_H | ||
#define ENTROPY_CALIBRATOR_H | ||
|
||
#include <NvInfer.h> | ||
#include <string> | ||
#include <vector> | ||
#include "macros.h" | ||
|
||
//! \class Int8EntropyCalibrator2 | ||
//! | ||
//! \brief Implements Entropy calibrator 2. | ||
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2. | ||
//! | ||
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 { | ||
public: | ||
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, | ||
const char* input_blob_name, bool read_cache = true); | ||
virtual ~Int8EntropyCalibrator2(); | ||
int getBatchSize() const TRT_NOEXCEPT override; | ||
bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override; | ||
const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override; | ||
void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override; | ||
|
||
private: | ||
int batchsize_; | ||
int input_w_; | ||
int input_h_; | ||
int img_idx_; | ||
std::string img_dir_; | ||
std::vector<std::string> img_files_; | ||
size_t input_count_; | ||
std::string calib_table_name_; | ||
const char* input_blob_name_; | ||
bool read_cache_; | ||
void* device_input_; | ||
std::vector<char> calib_cache_; | ||
}; | ||
|
||
#endif // ENTROPY_CALIBRATOR_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// #define USE_FP16 | ||
// #define USE_FP32 | ||
#define USE_INT8 | ||
|
||
const static char* kInputTensorName = "images"; | ||
const static char* kOutputTensorName = "output"; | ||
const static char* kProtoTensorName = "proto"; | ||
const static int kNumClass = 80; | ||
const static int kPoseNumClass = 1; | ||
const static int kNumberOfPoints = 17; // number of keypoints total | ||
const static int kBatchSize = 1; | ||
const static int kGpuId = 0; | ||
const static int kInputH = 640; | ||
const static int kInputW = 640; | ||
const static float kNmsThresh = 0.45f; | ||
const static float kConfThresh = 0.5f; | ||
const static float kConfThreshKeypoints = 0.5f; // keypoints confidence | ||
const static int kMaxInputImageSize = 3000 * 3000; | ||
const static int kMaxNumOutputBbox = 1000; | ||
//Quantization input image folder path | ||
const static char* kInputQuantizationFolder = "./coco_calib"; | ||
|
||
// Classfication model's number of classes | ||
constexpr static int kClsNumClass = 1000; | ||
// Classfication model's input shape | ||
constexpr static int kClsInputH = 224; | ||
constexpr static int kClsInputW = 224; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef TRTX_CUDA_UTILS_H_ | ||
#define TRTX_CUDA_UTILS_H_ | ||
|
||
#include <cuda_runtime_api.h> | ||
|
||
#ifndef CUDA_CHECK | ||
#define CUDA_CHECK(callstr) \ | ||
{ \ | ||
cudaError_t error_code = callstr; \ | ||
if (error_code != cudaSuccess) { \ | ||
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ | ||
assert(0); \ | ||
} \ | ||
} | ||
#endif // CUDA_CHECK | ||
|
||
#endif // TRTX_CUDA_UTILS_H_ |
Oops, something went wrong.