forked from wang-xinyu/tensorrtx
Commit
* IBN-Net: InstanceNorm2d, resnet50-ibna, resnet50-ibnb
* add ibnnet pytorch repo
Showing 14 changed files with 1,489 additions and 0 deletions.
ibnnet/CMakeLists.txt
@@ -0,0 +1,35 @@
cmake_minimum_required(VERSION 2.6)

project(IBNNet)

add_definitions(-std=c++11)

option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

find_package(CUDA REQUIRED)

include_directories(${PROJECT_SOURCE_DIR}/include)
# include and link dirs of CUDA and TensorRT; adapt them if yours differ
# cuda
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
# tensorrt
include_directories(/usr/include/x86_64-linux-gnu/)
link_directories(/usr/lib/x86_64-linux-gnu/)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -Wall -Ofast -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")

find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})

file(GLOB SOURCE_FILES "*.h" "*.cpp")

add_executable(ibnnet ${SOURCE_FILES})
target_link_libraries(ibnnet nvinfer)
target_link_libraries(ibnnet cudart)
target_link_libraries(ibnnet ${OpenCV_LIBS})

add_definitions(-O2 -pthread)
ibnnet/InferenceEngine.cpp
@@ -0,0 +1,93 @@
#include "InferenceEngine.h" | ||
|
||
namespace trt { | ||
|
||
InferenceEngine::InferenceEngine(const EngineConfig &enginecfg): _engineCfg(enginecfg) { | ||
|
||
assert(_engineCfg.max_batch_size > 0); | ||
|
||
CHECK(cudaSetDevice(_engineCfg.device_id)); | ||
|
||
_runtime = make_holder(nvinfer1::createInferRuntime(gLogger)); | ||
assert(_runtime); | ||
|
||
_engine = make_holder(_runtime->deserializeCudaEngine(_engineCfg.trtModelStream.get(), _engineCfg.stream_size)); | ||
assert(_engine); | ||
|
||
_context = make_holder(_engine->createExecutionContext()); | ||
assert(_context); | ||
|
||
_inputSize = _engineCfg.max_batch_size * 3 * _engineCfg.input_h * _engineCfg.input_w * _depth; | ||
_outputSize = _engineCfg.max_batch_size * _engineCfg.output_size * _depth; | ||
|
||
CHECK(cudaMallocHost((void**)&_data, _inputSize)); | ||
CHECK(cudaMallocHost((void**)&_prob, _outputSize)); | ||
|
||
_streamptr = std::shared_ptr<cudaStream_t>( new cudaStream_t, | ||
[](cudaStream_t* ptr){ | ||
cudaStreamDestroy(*ptr); | ||
if(ptr != nullptr){ | ||
delete ptr; | ||
} | ||
}); | ||
|
||
CHECK(cudaStreamCreate(&*_streamptr.get())); | ||
|
||
// Pointers to input and output device buffers to pass to engine. | ||
// Engine requires exactly IEngine::getNbBindings() number of buffers. | ||
assert(_engine->getNbBindings() == 2); | ||
|
||
// In order to bind the buffers, we need to know the names of the input and output tensors. | ||
// Note that indices are guaranteed to be less than IEngine::getNbBindings() | ||
_inputIndex = _engine->getBindingIndex(_engineCfg.input_name); | ||
_outputIndex = _engine->getBindingIndex(_engineCfg.output_name); | ||
|
||
// Create GPU buffers on device | ||
CHECK(cudaMalloc(&_buffers[_inputIndex], _inputSize)); | ||
CHECK(cudaMalloc(&_buffers[_outputIndex], _outputSize)); | ||
|
||
_inputSize /= _engineCfg.max_batch_size; | ||
_outputSize /= _engineCfg.max_batch_size; | ||
|
||
} | ||
|
||
bool InferenceEngine::doInference(const int inference_batch_size, std::function<void(float*)> preprocessing) { | ||
assert(inference_batch_size <= _engineCfg.max_batch_size); | ||
preprocessing(_data); | ||
CHECK(cudaSetDevice(_engineCfg.device_id)); | ||
CHECK(cudaMemcpyAsync(_buffers[_inputIndex], _data, inference_batch_size * _inputSize, cudaMemcpyHostToDevice, *_streamptr)); | ||
auto status = _context->enqueue(inference_batch_size, _buffers, *_streamptr, nullptr); | ||
CHECK(cudaMemcpyAsync(_prob, _buffers[_outputIndex], inference_batch_size * _outputSize, cudaMemcpyDeviceToHost, *_streamptr)); | ||
CHECK(cudaStreamSynchronize(*_streamptr)); | ||
return status; | ||
} | ||
|
||
InferenceEngine::InferenceEngine(InferenceEngine &&other) noexcept: | ||
_engineCfg(other._engineCfg) | ||
, _data(other._data) | ||
, _prob(other._prob) | ||
, _inputIndex(other._inputIndex) | ||
, _outputIndex(other._outputIndex) | ||
, _inputSize(other._inputSize) | ||
, _outputSize(other._outputSize) | ||
, _runtime(std::move(other._runtime)) | ||
, _engine(std::move(other._engine)) | ||
, _context(std::move(other._context)) | ||
, _streamptr(other._streamptr) { | ||
|
||
_buffers[0] = other._buffers[0]; | ||
_buffers[1] = other._buffers[1]; | ||
other._streamptr.reset(); | ||
other._data = nullptr; | ||
other._prob = nullptr; | ||
other._buffers[0] = nullptr; | ||
other._buffers[1] = nullptr; | ||
} | ||
|
||
InferenceEngine::~InferenceEngine() { | ||
CHECK(cudaFreeHost(_data)); | ||
CHECK(cudaFreeHost(_prob)); | ||
CHECK(cudaFree(_buffers[_inputIndex])); | ||
CHECK(cudaFree(_buffers[_outputIndex])); | ||
} | ||
} |
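For context, a caller drives this class roughly as follows. The binding names, shapes, and sizes below are illustrative assumptions, not taken from this commit; ibnnet.cpp (elsewhere in this changeset) is the authoritative wiring.

```cpp
// Illustrative caller; binding names and shape values are assumptions.
#include <algorithm>
#include <memory>
#include "InferenceEngine.h"

void classifyOnce(std::shared_ptr<char> plan, int planSize) {
    trt::EngineConfig cfg;
    cfg.input_name = "data";        // assumed binding name
    cfg.output_name = "prob";       // assumed binding name
    cfg.trtModelStream = plan;      // serialized engine blob
    cfg.stream_size = planSize;
    cfg.max_batch_size = 1;
    cfg.input_h = 224;
    cfg.input_w = 224;
    cfg.output_size = 1000;         // ImageNet class scores
    cfg.device_id = 0;

    trt::InferenceEngine engine(cfg);

    // The callback fills the pinned host buffer with CHW float data.
    engine.doInference(1, [](float* data) {
        std::fill(data, data + 3 * 224 * 224, 0.f);  // dummy input
    });
    const float* prob = engine.getOutput();  // valid until the next doInference
    (void)prob;
}
```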
ibnnet/InferenceEngine.h
@@ -0,0 +1,76 @@
/**************************************************************************
 * Handles memory pre-allocation, both on the host (pinned memory,
 * allowing CUDA DMA) and on the device.
 *************************************************************************/

#pragma once

#include <thread>
#include <chrono>
#include <memory>
#include <functional>
#include <opencv2/opencv.hpp>

#include "utils.h"
#include "holder.h"
#include "logging.h"
#include "NvInfer.h"
#include "cuda_runtime_api.h"

static Logger gLogger;

namespace trt {

struct EngineConfig {
    const char* input_name;
    const char* output_name;
    std::shared_ptr<char> trtModelStream;  // serialized engine blob
    int max_batch_size; /* create engine */
    int input_h;
    int input_w;
    int output_size;
    int stream_size;    // size of trtModelStream in bytes
    int device_id;
};

class InferenceEngine {

public:
    InferenceEngine(const EngineConfig &enginecfg);
    InferenceEngine(InferenceEngine &&other) noexcept;
    ~InferenceEngine();

    InferenceEngine(const InferenceEngine &) = delete;
    InferenceEngine& operator=(const InferenceEngine &) = delete;
    InferenceEngine& operator=(InferenceEngine &&other) = delete;

    bool doInference(const int inference_batch_size, std::function<void(float*)> preprocessing);
    float* getOutput() { return _prob; }
    std::thread::id getThreadID() { return std::this_thread::get_id(); }

private:
    EngineConfig _engineCfg;
    float* _data{nullptr};
    float* _prob{nullptr};

    // Pointers to input and output device buffers to pass to the engine.
    // Engine requires exactly IEngine::getNbBindings() number of buffers.
    void* _buffers[2];

    // In order to bind the buffers, we need to know the names of the input and output tensors.
    // Note that indices are guaranteed to be less than IEngine::getNbBindings().
    int _inputIndex;
    int _outputIndex;

    int _inputSize;
    int _outputSize;

    static constexpr std::size_t _depth{sizeof(float)};

    TensorRTHolder<nvinfer1::IRuntime> _runtime{nullptr};
    TensorRTHolder<nvinfer1::ICudaEngine> _engine{nullptr};
    TensorRTHolder<nvinfer1::IExecutionContext> _context{nullptr};
    std::shared_ptr<cudaStream_t> _streamptr;
};

}
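utils.h is not shown in this view of the commit. In tensorrtx it conventionally supplies the CHECK macro used throughout; a minimal sketch, assuming that convention (not necessarily this commit's exact definition):

```cpp
// Hedged sketch of the CHECK macro utils.h presumably provides (tensorrtx convention):
// abort loudly if a CUDA runtime call does not return cudaSuccess.
#include <cassert>
#include <iostream>
#include <cuda_runtime_api.h>

#define CHECK(call)                                                          \
    do {                                                                     \
        const cudaError_t error_code = (call);                               \
        if (error_code != cudaSuccess) {                                     \
            std::cerr << "CUDA error " << cudaGetErrorString(error_code)     \
                      << " at " << __FILE__ << ":" << __LINE__ << "\n";      \
            assert(0);                                                       \
        }                                                                    \
    } while (0)
```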
ibnnet/README.md
@@ -0,0 +1,46 @@
# IBN-Net

An implementation of IBN-Net, proposed in ["Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"](https://arxiv.org/abs/1807.09441), ECCV 2018, by Xingang Pan, Ping Luo, Jianping Shi, and Xiaoou Tang.

For the PyTorch implementation, see [IBN-Net](https://github.com/XingangPan/IBN-Net).

## Features
- InstanceNorm2d
- bottleneck_ibn (see the sketch after this list)
- Resnet50-IBNA
- Resnet50-IBNB
- Multi-thread inference
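For orientation: in IBN-a, the bottleneck's first conv output has InstanceNorm applied to half of its channels and BatchNorm to the other half, and the two halves are concatenated back along the channel axis (IBN-b instead applies InstanceNorm after the residual addition of some blocks). A hedged sketch of that split in the TensorRT network-definition API follows; `addInstanceNorm2d` and `addBatchNorm2d` are hypothetical helper names, not necessarily this commit's API.

```cpp
// Hedged sketch of the IBN-a channel split; helper names are hypothetical.
#include "NvInfer.h"

// Assumed to exist elsewhere (e.g. built from plugins or scale layers).
nvinfer1::ITensor* addInstanceNorm2d(nvinfer1::INetworkDefinition* net, nvinfer1::ITensor& x);
nvinfer1::ITensor* addBatchNorm2d(nvinfer1::INetworkDefinition* net, nvinfer1::ITensor& x);

nvinfer1::ITensor* ibn_a(nvinfer1::INetworkDefinition* net, nvinfer1::ITensor& x, int channels) {
    const int half = channels / 2;
    const nvinfer1::Dims d = x.getDimensions();  // CHW in implicit-batch mode

    // First half of the channels -> InstanceNorm, second half -> BatchNorm.
    auto* inHalf = net->addSlice(x, nvinfer1::Dims3{0, 0, 0},
                                 nvinfer1::Dims3{half, d.d[1], d.d[2]},
                                 nvinfer1::Dims3{1, 1, 1})->getOutput(0);
    auto* bnHalf = net->addSlice(x, nvinfer1::Dims3{half, 0, 0},
                                 nvinfer1::Dims3{channels - half, d.d[1], d.d[2]},
                                 nvinfer1::Dims3{1, 1, 1})->getOutput(0);

    nvinfer1::ITensor* branches[] = {
        addInstanceNorm2d(net, *inHalf),
        addBatchNorm2d(net, *bnHalf)
    };
    // Default concatenation axis is the channel axis in implicit-batch mode.
    return net->addConcatenation(branches, 2)->getOutput(0);
}
```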
## How to Run

1. Generate the .wts file

For ibn-a:
```
python gen_wts.py a
```
A file 'resnet50-ibna.wts' will be generated.

For ibn-b:
```
python gen_wts.py b
```
A file 'resnet50-ibnb.wts' will be generated.

2. cmake and make

```
mkdir build
cd build
cmake ..
make
```

3. Build the engine and run classification

Put resnet50-ibna.wts/resnet50-ibnb.wts into tensorrtx/ibnnet, then go to tensorrtx/ibnnet:

```
./ibnnet -s // serialize model to plan file
./ibnnet -d // deserialize plan file and run inference
```
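The -s/-d split follows the usual tensorrtx pattern of writing and then reading back a serialized plan file. A sketch under that assumption (the plan file name "ibnnet.engine" and these helper names are illustrative, not from this commit); note how readPlan produces exactly the `std::shared_ptr<char>` blob plus size that `EngineConfig` expects:

```cpp
// Assumed -s/-d plumbing, following tensorrtx convention.
#include <fstream>
#include <memory>
#include <string>
#include "NvInfer.h"

// -s: serialize a built engine and write it to a plan file.
void writePlan(nvinfer1::ICudaEngine* engine, const std::string& path) {
    nvinfer1::IHostMemory* modelStream = engine->serialize();
    std::ofstream p(path, std::ios::binary);
    p.write(static_cast<const char*>(modelStream->data()), modelStream->size());
    modelStream->destroy();
}

// -d: read the plan back as the shared_ptr<char> blob EngineConfig expects.
std::shared_ptr<char> readPlan(const std::string& path, int& size) {
    std::ifstream file(path, std::ios::binary);
    file.seekg(0, file.end);
    size = static_cast<int>(file.tellg());
    file.seekg(0, file.beg);
    std::shared_ptr<char> blob(new char[size], std::default_delete<char[]>());
    file.read(blob.get(), size);
    return blob;
}
```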
ibnnet/gen_wts.py
@@ -0,0 +1,30 @@
# Usage: python gen_wts.py [a|b]
import torch
import os
import sys
import struct

assert sys.argv[1] == "a" or sys.argv[1] == "b"
model_name = "resnet50_ibn_" + sys.argv[1]

net = torch.hub.load('XingangPan/IBN-Net', model_name, pretrained=True).to('cuda:0').eval()

# verify
# input = torch.ones(1, 3, 224, 224).to('cuda:0')
# pixel_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1).to('cuda:0')
# pixel_std = torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1).to('cuda:0')
# input.sub_(pixel_mean).div_(pixel_std)
# out = net(input)
# print(out)

# Serialize weights: a count line, then one line per tensor:
# "<name> <element count> <hex-encoded big-endian floats...>".
f = open(model_name + ".wts", 'w')
f.write("{}\n".format(len(net.state_dict().keys())))
for k, v in net.state_dict().items():
    vr = v.reshape(-1).cpu().numpy()
    f.write("{} {}".format(k, len(vr)))
    for vv in vr:
        f.write(" ")
        f.write(struct.pack(">f", float(vv)).hex())
    f.write("\n")
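The hex format written above is the standard tensorrtx .wts convention, and the C++ side presumably parses it with the usual loader. A sketch of that conventional loader (this commit's actual parser may differ); reading each hex word with `std::hex` into a `uint32_t` recovers the float's raw bits:

```cpp
// Sketch of the conventional tensorrtx .wts loader this format is written for.
#include <cstdint>
#include <fstream>
#include <map>
#include <string>
#include "NvInfer.h"

std::map<std::string, nvinfer1::Weights> loadWeights(const std::string& file) {
    std::map<std::string, nvinfer1::Weights> weightMap;
    std::ifstream input(file);
    int32_t count;
    input >> count;                        // first line: number of tensors
    while (count--) {
        nvinfer1::Weights wt{nvinfer1::DataType::kFLOAT, nullptr, 0};
        std::string name;
        uint32_t size;
        input >> name >> std::dec >> size;
        uint32_t* val = new uint32_t[size];   // raw bits of the packed floats
        for (uint32_t x = 0; x < size; ++x) input >> std::hex >> val[x];
        wt.values = val;                   // released by the caller after the engine is built
        wt.count = size;
        weightMap[name] = wt;
    }
    return weightMap;
}
```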
ibnnet/holder.h
@@ -0,0 +1,41 @@
#pragma once

// RAII wrapper for TensorRT objects, which are released via destroy()
// rather than delete.
template <typename T>
class TensorRTHolder {
    T* holder;
public:
    explicit TensorRTHolder(T* holder_) : holder(holder_) {}
    ~TensorRTHolder() {
        if (holder)
            holder->destroy();
    }
    TensorRTHolder(const TensorRTHolder&) = delete;
    TensorRTHolder& operator=(const TensorRTHolder&) = delete;
    TensorRTHolder(TensorRTHolder&& rhs) noexcept {
        holder = rhs.holder;
        rhs.holder = nullptr;
    }
    TensorRTHolder& operator=(TensorRTHolder&& rhs) noexcept {
        if (this == &rhs) {
            return *this;
        }
        if (holder) holder->destroy();
        holder = rhs.holder;
        rhs.holder = nullptr;
        return *this;
    }
    T* operator->() { return holder; }
    T* get() { return holder; }
    explicit operator bool() { return holder != nullptr; }
    T& operator*() noexcept { return *holder; }
};

template <typename T>
TensorRTHolder<T> make_holder(T* holder) {
    return TensorRTHolder<T>(holder);
}

// A raw, non-owning alias, for contrast with the owning holder above.
template <typename T>
using TensorRTNonHolder = T*;
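A minimal usage sketch: wrap any create*/destroy-style TensorRT object and let scope exit release it, as InferenceEngine.cpp does with its runtime, engine, and context. The blob/size arguments are placeholders.

```cpp
// Minimal usage sketch for make_holder; blob and size are placeholders.
#include "NvInfer.h"
#include "holder.h"
#include "logging.h"

static Logger gLogger;  // same pattern as InferenceEngine.h above

void buildAndDrop(const void* blob, std::size_t size) {
    auto runtime = make_holder(nvinfer1::createInferRuntime(gLogger));
    auto engine  = make_holder(runtime->deserializeCudaEngine(blob, size));
    // destroy() runs automatically for engine, then runtime, at scope exit.
}
```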