IBN-Net (wang-xinyu#362)
* IBN-Net

InstanceNorm2d

resnet50-ibna

resnet50-ibnb

* add ibnnet pytorch repo
TCHeish authored Jan 19, 2021
1 parent 7ba93e2 commit 556665c
Showing 14 changed files with 1,489 additions and 0 deletions.
35 changes: 35 additions & 0 deletions ibnnet/CMakeLists.txt
@@ -0,0 +1,35 @@
cmake_minimum_required(VERSION 2.6)

project(IBNNet)

add_definitions(-std=c++11)

option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

find_package(CUDA REQUIRED)

include_directories(${PROJECT_SOURCE_DIR}/include)
# include and link dirs of cuda and tensorrt; you may need to adapt them if yours differ
# cuda
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
# tensorrt
include_directories(/usr/include/x86_64-linux-gnu/)
link_directories(/usr/lib/x86_64-linux-gnu/)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -Wall -Ofast -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")

find_package(OpenCV)
include_directories(${OpenCV_INCLUDE_DIRS})

file(GLOB SOURCE_FILES "*.h" "*.cpp")

add_executable(ibnnet ${SOURCE_FILES})
target_link_libraries(ibnnet nvinfer)
target_link_libraries(ibnnet cudart)
target_link_libraries(ibnnet ${OpenCV_LIBS})

add_definitions(-O2 -pthread)

93 changes: 93 additions & 0 deletions ibnnet/InferenceEngine.cpp
@@ -0,0 +1,93 @@
#include "InferenceEngine.h"

namespace trt {

InferenceEngine::InferenceEngine(const EngineConfig &enginecfg): _engineCfg(enginecfg) {

assert(_engineCfg.max_batch_size > 0);

CHECK(cudaSetDevice(_engineCfg.device_id));

_runtime = make_holder(nvinfer1::createInferRuntime(gLogger));
assert(_runtime);

_engine = make_holder(_runtime->deserializeCudaEngine(_engineCfg.trtModelStream.get(), _engineCfg.stream_size));
assert(_engine);

_context = make_holder(_engine->createExecutionContext());
assert(_context);

_inputSize = _engineCfg.max_batch_size * 3 * _engineCfg.input_h * _engineCfg.input_w * _depth;
_outputSize = _engineCfg.max_batch_size * _engineCfg.output_size * _depth;

CHECK(cudaMallocHost((void**)&_data, _inputSize));
CHECK(cudaMallocHost((void**)&_prob, _outputSize));

_streamptr = std::shared_ptr<cudaStream_t>( new cudaStream_t,
[](cudaStream_t* ptr){
cudaStreamDestroy(*ptr);
if(ptr != nullptr){
delete ptr;
}
});

CHECK(cudaStreamCreate(&*_streamptr.get()));

// Pointers to input and output device buffers to pass to engine.
// Engine requires exactly IEngine::getNbBindings() number of buffers.
assert(_engine->getNbBindings() == 2);

// In order to bind the buffers, we need to know the names of the input and output tensors.
// Note that indices are guaranteed to be less than IEngine::getNbBindings()
_inputIndex = _engine->getBindingIndex(_engineCfg.input_name);
_outputIndex = _engine->getBindingIndex(_engineCfg.output_name);

// Create GPU buffers on device
CHECK(cudaMalloc(&_buffers[_inputIndex], _inputSize));
CHECK(cudaMalloc(&_buffers[_outputIndex], _outputSize));

// from here on, _inputSize / _outputSize hold per-sample byte counts;
// doInference() scales them by the runtime batch size
_inputSize /= _engineCfg.max_batch_size;
_outputSize /= _engineCfg.max_batch_size;

}

bool InferenceEngine::doInference(const int inference_batch_size, std::function<void(float*)> preprocessing) {
assert(inference_batch_size <= _engineCfg.max_batch_size);
preprocessing(_data);
CHECK(cudaSetDevice(_engineCfg.device_id));
CHECK(cudaMemcpyAsync(_buffers[_inputIndex], _data, inference_batch_size * _inputSize, cudaMemcpyHostToDevice, *_streamptr));
auto status = _context->enqueue(inference_batch_size, _buffers, *_streamptr, nullptr);
CHECK(cudaMemcpyAsync(_prob, _buffers[_outputIndex], inference_batch_size * _outputSize, cudaMemcpyDeviceToHost, *_streamptr));
CHECK(cudaStreamSynchronize(*_streamptr));
return status;
}

InferenceEngine::InferenceEngine(InferenceEngine &&other) noexcept:
_engineCfg(other._engineCfg)
, _data(other._data)
, _prob(other._prob)
, _inputIndex(other._inputIndex)
, _outputIndex(other._outputIndex)
, _inputSize(other._inputSize)
, _outputSize(other._outputSize)
, _runtime(std::move(other._runtime))
, _engine(std::move(other._engine))
, _context(std::move(other._context))
, _streamptr(other._streamptr) {

_buffers[0] = other._buffers[0];
_buffers[1] = other._buffers[1];
other._streamptr.reset();
other._data = nullptr;
other._prob = nullptr;
other._buffers[0] = nullptr;
other._buffers[1] = nullptr;
}

InferenceEngine::~InferenceEngine() {
CHECK(cudaFreeHost(_data));
CHECK(cudaFreeHost(_prob));
CHECK(cudaFree(_buffers[_inputIndex]));
CHECK(cudaFree(_buffers[_outputIndex]));
}
}
76 changes: 76 additions & 0 deletions ibnnet/InferenceEngine.h
@@ -0,0 +1,76 @@
/**************************************************************************
* Handles memory pre-allocation on both the host (pinned memory,
* which allows CUDA DMA) and the device.
*************************************************************************/

#pragma once

#include <thread>
#include <chrono>
#include <memory>
#include <functional>
#include <opencv2/opencv.hpp>

#include "utils.h"
#include "holder.h"
#include "logging.h"
#include "NvInfer.h"
#include "cuda_runtime_api.h"
static Logger gLogger;

namespace trt {

struct EngineConfig {
const char* input_name;
const char* output_name;
std::shared_ptr<char> trtModelStream; /* serialized engine bytes */
int max_batch_size; /* batch size the engine is built with */
int input_h;
int input_w;
int output_size;
int stream_size; /* size of trtModelStream in bytes */
int device_id;
};

class InferenceEngine {

public:
InferenceEngine(const EngineConfig &enginecfg);
InferenceEngine(InferenceEngine &&other) noexcept;
~InferenceEngine();

InferenceEngine(const InferenceEngine &) = delete;
InferenceEngine& operator=(const InferenceEngine &) = delete;
InferenceEngine& operator=(InferenceEngine && other) = delete;

bool doInference(const int inference_batch_size, std::function<void(float*)> preprocessing);
float* getOutput() { return _prob; }
std::thread::id getThreadID() { return std::this_thread::get_id(); }

private:
EngineConfig _engineCfg;
float* _data{nullptr};
float* _prob{nullptr};

// Pointers to input and output device buffers to pass to engine.
// Engine requires exactly IEngine::getNbBindings() number of buffers.
void* _buffers[2];

// In order to bind the buffers, we need to know the names of the input and output tensors.
// Note that indices are guaranteed to be less than IEngine::getNbBindings()
int _inputIndex;
int _outputIndex;

int _inputSize;
int _outputSize;

static constexpr std::size_t _depth{sizeof(float)};

TensorRTHolder<nvinfer1::IRuntime> _runtime{nullptr};
TensorRTHolder<nvinfer1::ICudaEngine> _engine{nullptr};
TensorRTHolder<nvinfer1::IExecutionContext> _context{nullptr};
std::shared_ptr<cudaStream_t> _streamptr;
};

}

46 changes: 46 additions & 0 deletions ibnnet/README.md
@@ -0,0 +1,46 @@
# IBN-Net

A TensorRT implementation of IBN-Net, proposed in ["Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"](https://arxiv.org/abs/1807.09441), ECCV 2018, by Xingang Pan, Ping Luo, Jianping Shi, and Xiaoou Tang.

For the PyTorch implementation, refer to [IBN-Net](https://github.com/XingangPan/IBN-Net).

## Features
- InstanceNorm2d
- bottleneck_ibn (the IBN-a bottleneck block; see the sketch below)
- Resnet50-IBNA
- Resnet50-IBNB
- Multi-thread inference
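
In the bottleneck_ibn block listed above (the IBN-a design from the paper), the output of the bottleneck's first 1x1 convolution is split along the channel axis: half the channels are normalized with InstanceNorm2d, the other half with BatchNorm2d, and the two halves are concatenated again. IBN-b instead inserts an extra InstanceNorm2d after the residual addition in the early stages. A minimal PyTorch sketch of the IBN-a idea, using illustrative names rather than this repository's actual classes:

```
import torch
import torch.nn as nn

class IBN(nn.Module):
    """Sketch of IBN-a normalization: half the channels go through
    InstanceNorm2d, the remaining half through BatchNorm2d."""
    def __init__(self, planes):
        super().__init__()
        self.half = planes // 2
        self.IN = nn.InstanceNorm2d(self.half, affine=True)
        self.BN = nn.BatchNorm2d(planes - self.half)

    def forward(self, x):
        a, b = torch.split(x, [self.half, x.shape[1] - self.half], dim=1)
        return torch.cat((self.IN(a), self.BN(b)), dim=1)

# Example: normalize a 64-channel feature map.
y = IBN(64)(torch.randn(2, 64, 56, 56))
print(y.shape)  # torch.Size([2, 64, 56, 56])
```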

## How to Run

* 1. Generate .wts

For ibn-a:
```
python gen_wts.py a
```
A file 'resnet50-ibna.wts' will be generated.

For ibn-b:
```
python gen_wts.py b
```
A file 'resnet50-ibnb.wts' will be generated.
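
The .wts file is plain text: the first line is the number of weight tensors, and each subsequent line holds a tensor name, its element count, and that many float32 values written as big-endian hex strings (see gen_wts.py below). A small sketch for reading such a file back in Python, e.g. to sanity-check the export (load_wts is an illustrative helper, not part of this repo):

```
import struct

def load_wts(path):
    """Parse a tensorrtx-style .wts text file into {name: list of floats}."""
    weights = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            parts = f.readline().split()
            name, n = parts[0], int(parts[1])
            # each value is a float32 encoded as 8 hex characters, big-endian
            weights[name] = [struct.unpack(">f", bytes.fromhex(h))[0]
                             for h in parts[2:2 + n]]
    return weights

w = load_wts("resnet50-ibna.wts")
print(len(w), "tensors loaded")
```
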
* 2. Build with cmake and make

```
mkdir build
cd build
cmake ..
make
```
* 3. Build the engine and run classification

Put resnet50-ibna.wts or resnet50-ibnb.wts into tensorrtx/ibnnet, then go to tensorrtx/ibnnet and run:
```
./ibnnet -s // serialize model to plan file
./ibnnet -d // deserialize plan file and run inference
```
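
To cross-check the TensorRT output against PyTorch, the commented-out verification block in gen_wts.py suggests the model expects ImageNet-style normalization. A hedged preprocessing sketch (the input size, layout, and helper name are assumptions; the actual C++ preprocessing in this folder may differ):

```
import cv2
import numpy as np

def preprocess(img_path, h=224, w=224):
    # assumed ImageNet-style preprocessing, not taken from the C++ sources
    img = cv2.imread(img_path)
    img = cv2.resize(img, (w, h))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    return ((img - mean) / std).transpose(2, 0, 1)[None]  # NCHW

print(preprocess("sample.jpg").shape)  # (1, 3, 224, 224)
```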

30 changes: 30 additions & 0 deletions ibnnet/gen_wts.py
@@ -0,0 +1,30 @@
import torch
import os
import sys
import struct


assert sys.argv[1] in ("a", "b"), "usage: python gen_wts.py [a|b]"
model_name = "resnet50_ibn_" + sys.argv[1]

net = torch.hub.load('XingangPan/IBN-Net', model_name, pretrained=True).to('cuda:0').eval()

#verify
#input = torch.ones(1, 3, 224, 224).to('cuda:0')
#pixel_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1).to('cuda:0')
#pixel_std = torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1).to('cuda:0')
#input.sub_(pixel_mean).div_(pixel_std)
#out = net(input)
#print(out)

# Write weights in the tensorrtx .wts text format:
# first line is the number of tensors, then one line per tensor:
# "<name> <count> <hex-encoded big-endian float32> ..."
with open(model_name + ".wts", 'w') as f:
    f.write("{}\n".format(len(net.state_dict().keys())))
    for k, v in net.state_dict().items():
        vr = v.reshape(-1).cpu().numpy()
        f.write("{} {}".format(k, len(vr)))
        for vv in vr:
            f.write(" ")
            f.write(struct.pack(">f", float(vv)).hex())
        f.write("\n")


41 changes: 41 additions & 0 deletions ibnnet/holder.h
@@ -0,0 +1,41 @@
#pragma once

template <typename T>
class TensorRTHolder {
T* holder;
public:
explicit TensorRTHolder(T* holder_) : holder(holder_) {}
~TensorRTHolder() {
if (holder)
holder->destroy();
}
TensorRTHolder(const TensorRTHolder&) = delete;
TensorRTHolder& operator=(const TensorRTHolder&) = delete;
TensorRTHolder(TensorRTHolder && rhs) noexcept{
holder = rhs.holder;
rhs.holder = nullptr;
}
TensorRTHolder& operator=(TensorRTHolder&& rhs) noexcept {
if (this == &rhs) {
return *this;
}
if (holder) holder->destroy();
holder = rhs.holder;
rhs.holder = nullptr;
return *this;
}
T* operator->() {
return holder;
}
T* get() { return holder; }
explicit operator bool() { return holder != nullptr; }
T& operator*() noexcept { return *holder; }
};

template <typename T>
TensorRTHolder<T> make_holder(T* holder) {
return TensorRTHolder<T>(holder);
}

template <typename T>
using TensorRTNonHolder = T*;