diff --git a/ibnnet/CMakeLists.txt b/ibnnet/CMakeLists.txt
new file mode 100644
index 00000000..93b1132e
--- /dev/null
+++ b/ibnnet/CMakeLists.txt
@@ -0,0 +1,35 @@
+cmake_minimum_required(VERSION 2.6)
+
+project(IBNNet)
+
+add_definitions(-std=c++11)
+
+option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
+set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_BUILD_TYPE Debug)
+
+find_package(CUDA REQUIRED)
+
+include_directories(${PROJECT_SOURCE_DIR}/include)
+# include and link dirs of cuda and tensorrt; adapt these paths if your installation differs
+# cuda
+include_directories(/usr/local/cuda/include)
+link_directories(/usr/local/cuda/lib64)
+# tensorrt
+include_directories(/usr/include/x86_64-linux-gnu/)
+link_directories(/usr/lib/x86_64-linux-gnu/)
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -Wall -Ofast -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")
+
+find_package(OpenCV REQUIRED)
+include_directories(${OpenCV_INCLUDE_DIRS})
+
+file(GLOB SOURCE_FILES "*.h" "*.cpp")
+
+add_executable(ibnnet ${SOURCE_FILES})
+target_link_libraries(ibnnet nvinfer)
+target_link_libraries(ibnnet cudart)
+target_link_libraries(ibnnet ${OpenCV_LIBS})
+
+add_definitions(-O2 -pthread)
+
diff --git a/ibnnet/InferenceEngine.cpp b/ibnnet/InferenceEngine.cpp
new file mode 100755
index 00000000..ef69e116
--- /dev/null
+++ b/ibnnet/InferenceEngine.cpp
@@ -0,0 +1,93 @@
+#include "InferenceEngine.h"
+
+namespace trt {
+
+    InferenceEngine::InferenceEngine(const EngineConfig &enginecfg): _engineCfg(enginecfg) {
+
+        assert(_engineCfg.max_batch_size > 0);
+
+        CHECK(cudaSetDevice(_engineCfg.device_id));
+
+        _runtime = make_holder(nvinfer1::createInferRuntime(gLogger));
+        assert(_runtime);
+
+        _engine = make_holder(_runtime->deserializeCudaEngine(_engineCfg.trtModelStream.get(), _engineCfg.stream_size));
+        assert(_engine);
+
+        _context = make_holder(_engine->createExecutionContext());
+        assert(_context);
+
+        _inputSize = _engineCfg.max_batch_size * 3 * _engineCfg.input_h * _engineCfg.input_w * _depth;
+        _outputSize = _engineCfg.max_batch_size * _engineCfg.output_size * _depth;
+
+        // Pinned host memory enables asynchronous (DMA) transfers via cudaMemcpyAsync.
+        CHECK(cudaMallocHost((void**)&_data, _inputSize));
+        CHECK(cudaMallocHost((void**)&_prob, _outputSize));
+
+        _streamptr = std::shared_ptr<cudaStream_t>( new cudaStream_t,
+                        [](cudaStream_t* ptr){
+                            if (ptr != nullptr) {
+                                cudaStreamDestroy(*ptr);
+                                delete ptr;
+                            }
+                        });
+
+        CHECK(cudaStreamCreate(_streamptr.get()));
+
+        // Pointers to input and output device buffers to pass to the engine.
+        // The engine requires exactly IEngine::getNbBindings() buffers.
+        assert(_engine->getNbBindings() == 2);
+
+        // In order to bind the buffers, we need to know the names of the input and output tensors.
+        // Note that indices are guaranteed to be less than IEngine::getNbBindings().
+        _inputIndex = _engine->getBindingIndex(_engineCfg.input_name);
+        _outputIndex = _engine->getBindingIndex(_engineCfg.output_name);
+
+        // Create GPU buffers on device
+        CHECK(cudaMalloc(&_buffers[_inputIndex], _inputSize));
+        CHECK(cudaMalloc(&_buffers[_outputIndex], _outputSize));
+
+        // From here on, keep the per-image sizes (in bytes) for batched copies.
+        _inputSize /= _engineCfg.max_batch_size;
+        _outputSize /= _engineCfg.max_batch_size;
+
+    }
+
+    bool InferenceEngine::doInference(const int inference_batch_size, std::function<void(float*)> preprocessing) {
+        assert(inference_batch_size <= _engineCfg.max_batch_size);
+        preprocessing(_data);
+        CHECK(cudaSetDevice(_engineCfg.device_id));
+        CHECK(cudaMemcpyAsync(_buffers[_inputIndex], _data, inference_batch_size * _inputSize, cudaMemcpyHostToDevice, *_streamptr));
+        auto status = _context->enqueue(inference_batch_size, _buffers, *_streamptr, nullptr);
+        CHECK(cudaMemcpyAsync(_prob, _buffers[_outputIndex], inference_batch_size * _outputSize, cudaMemcpyDeviceToHost, *_streamptr));
+        CHECK(cudaStreamSynchronize(*_streamptr));
+        return status;
+    }
+
+    InferenceEngine::InferenceEngine(InferenceEngine &&other) noexcept:
+        _engineCfg(other._engineCfg)
+        , _data(other._data)
+        , _prob(other._prob)
+        , _inputIndex(other._inputIndex)
+        , _outputIndex(other._outputIndex)
+        , _inputSize(other._inputSize)
+        , _outputSize(other._outputSize)
+        , _runtime(std::move(other._runtime))
+        , _engine(std::move(other._engine))
+        , _context(std::move(other._context))
+        , _streamptr(other._streamptr) {
+
+        _buffers[0] = other._buffers[0];
+        _buffers[1] = other._buffers[1];
+        other._streamptr.reset();
+        other._data = nullptr;
+        other._prob = nullptr;
+        other._buffers[0] = nullptr;
+        other._buffers[1] = nullptr;
+    }
+
+    InferenceEngine::~InferenceEngine() {
+        CHECK(cudaFreeHost(_data));
+        CHECK(cudaFreeHost(_prob));
+        CHECK(cudaFree(_buffers[_inputIndex]));
+        CHECK(cudaFree(_buffers[_outputIndex]));
+    }
+}
\ No newline at end of file
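For reference, a minimal caller sketch (hypothetical helper, not part of the diff) showing how `doInference` pairs the pinned staging buffer with the copy/enqueue/copy chain; `cfg` is assumed to be an `EngineConfig` whose `trtModelStream` already holds a deserialized plan:

```cpp
// Usage sketch only; `cfg` and `host_input` are assumptions.
#include <algorithm>
#include <vector>
#include "InferenceEngine.h"

void run_once(const trt::EngineConfig& cfg, const std::vector<float>& host_input) {
    trt::InferenceEngine engine(cfg);
    // The lambda fills the pinned host buffer; doInference then copies
    // host->device, enqueues the batch, and copies the result back on one stream.
    engine.doInference(1, [&](float* data) {
        std::copy(host_input.begin(), host_input.end(), data);
    });
    const float* prob = engine.getOutput(); // valid until the next doInference call
    (void)prob;
}
```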
diff --git a/ibnnet/InferenceEngine.h b/ibnnet/InferenceEngine.h
new file mode 100755
index 00000000..dd71223b
--- /dev/null
+++ b/ibnnet/InferenceEngine.h
@@ -0,0 +1,76 @@
+/**************************************************************************
+ * Handle memory pre-alloc
+ * both on host (pinned memory, allows CUDA DMA) & device
+*************************************************************************/
+
+#pragma once
+
+#include <memory>
+#include <thread>
+#include <functional>
+#include <iostream>
+#include <cstddef>
+
+#include "utils.h"
+#include "holder.h"
+#include "logging.h"
+#include "NvInfer.h"
+#include "cuda_runtime_api.h"
+static Logger gLogger;
+
+namespace trt {
+
+    struct EngineConfig {
+        const char* input_name;
+        const char* output_name;
+        std::shared_ptr<char> trtModelStream;
+        int max_batch_size;   /* batch size used to create the engine */
+        int input_h;
+        int input_w;
+        int output_size;
+        int stream_size;      /* size of trtModelStream in bytes */
+        int device_id;
+    };
+
+    class InferenceEngine {
+
+    public:
+        InferenceEngine(const EngineConfig &enginecfg);
+        InferenceEngine(InferenceEngine &&other) noexcept;
+        ~InferenceEngine();
+
+        InferenceEngine(const InferenceEngine &) = delete;
+        InferenceEngine& operator=(const InferenceEngine &) = delete;
+        InferenceEngine& operator=(InferenceEngine && other) = delete;
+
+        bool doInference(const int inference_batch_size, std::function<void(float*)> preprocessing);
+        float* getOutput() { return _prob; }
+        std::thread::id getThreadID() { return std::this_thread::get_id(); }
+
+    private:
+        EngineConfig _engineCfg;
+        float* _data{nullptr};
+        float* _prob{nullptr};
+
+        // Pointers to input and output device buffers to pass to the engine.
+        // The engine requires exactly IEngine::getNbBindings() buffers.
+        void* _buffers[2];
+
+        // In order to bind the buffers, we need to know the names of the input and output tensors.
+        // Note that indices are guaranteed to be less than IEngine::getNbBindings().
+        int _inputIndex;
+        int _outputIndex;
+
+        int _inputSize;
+        int _outputSize;
+
+        static constexpr std::size_t _depth{sizeof(float)};
+
+        TensorRTHolder<nvinfer1::IRuntime> _runtime{nullptr};
+        TensorRTHolder<nvinfer1::ICudaEngine> _engine{nullptr};
+        TensorRTHolder<nvinfer1::IExecutionContext> _context{nullptr};
+        std::shared_ptr<cudaStream_t> _streamptr;
+    };
+
+}
+
diff --git a/ibnnet/README.md b/ibnnet/README.md
new file mode 100644
index 00000000..206b39ca
--- /dev/null
+++ b/ibnnet/README.md
@@ -0,0 +1,46 @@
+# IBN-Net
+
+An implementation of IBN-Net, proposed in ["Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"](https://arxiv.org/abs/1807.09441), ECCV 2018, by Xingang Pan, Ping Luo, Jianping Shi, and Xiaoou Tang.
+
+For the PyTorch implementation, see [IBN-Net](https://github.com/XingangPan/IBN-Net).
+
+## Features
+- InstanceNorm2d
+- bottleneck_ibn
+- Resnet50-IBNA
+- Resnet50-IBNB
+- Multi-thread inference
+
+## How to Run
+
+* 1. generate .wts
+
+    // for ibn-a
+    ```
+    python gen_wts.py a
+    ```
+    A file 'resnet50-ibna.wts' will be generated.
+
+    // for ibn-b
+    ```
+    python gen_wts.py b
+    ```
+    A file 'resnet50-ibnb.wts' will be generated.
+* 2. cmake and make
+
+    ```
+    mkdir build
+    cd build
+    cmake ..
+    make
+    ```
+* 3. build engine and run classification
+
+    // put resnet50-ibna.wts/resnet50-ibnb.wts into tensorrtx/ibnnet
+
+    // go to tensorrtx/ibnnet/build (the engine is loaded from ../ relative to the working directory)
+    ```
+    ./ibnnet -s   // serialize model to plan file
+    ./ibnnet -d   // deserialize plan file and run inference
+    ```
\ No newline at end of file
diff --git a/ibnnet/gen_wts.py b/ibnnet/gen_wts.py
new file mode 100755
index 00000000..c77d3f9a
--- /dev/null
+++ b/ibnnet/gen_wts.py
@@ -0,0 +1,30 @@
+import torch
+import os
+import sys
+import struct
+
+
+assert sys.argv[1] == "a" or sys.argv[1] == "b"
+model_name = "resnet50_ibn_" + sys.argv[1]
+wts_name = "resnet50-ibn" + sys.argv[1] + ".wts"  # matches the path loaded in ibnnet.cpp
+
+net = torch.hub.load('XingangPan/IBN-Net', model_name, pretrained=True).to('cuda:0').eval()
+
+# verify
+# input = torch.ones(1, 3, 224, 224).to('cuda:0')
+# pixel_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1).to('cuda:0')
+# pixel_std = torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1).to('cuda:0')
+# input.sub_(pixel_mean).div_(pixel_std)
+# out = net(input)
+# print(out)
+
+f = open(wts_name, 'w')
+f.write("{}\n".format(len(net.state_dict().keys())))
+for k, v in net.state_dict().items():
+    vr = v.reshape(-1).cpu().numpy()
+    f.write("{} {}".format(k, len(vr)))
+    for vv in vr:
+        f.write(" ")
+        f.write(struct.pack(">f", float(vv)).hex())
+    f.write("\n")
+
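The `.wts` format pairs each parameter name with a count and big-endian hex dumps of the raw IEEE-754 bits. A small sketch (assumed helper, mirroring what `loadWeights` in utils.cpp does) of how one token decodes back to a float:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// "3f800000" -> 1.0f: parse the hex digits into a 32-bit pattern,
// then reinterpret that pattern as an IEEE-754 float.
float wts_token_to_float(const char* hex_token) {
    uint32_t bits = 0;
    std::sscanf(hex_token, "%x", &bits);
    float value;
    std::memcpy(&value, &bits, sizeof(value));
    return value;
}
```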
diff --git a/ibnnet/holder.h b/ibnnet/holder.h
new file mode 100755
index 00000000..43334c3a
--- /dev/null
+++ b/ibnnet/holder.h
@@ -0,0 +1,41 @@
+#pragma once
+
+// RAII wrapper for TensorRT objects, which must be released via destroy()
+// rather than delete.
+template <typename T>
+class TensorRTHolder {
+    T* holder;
+public:
+    explicit TensorRTHolder(T* holder_) : holder(holder_) {}
+    ~TensorRTHolder() {
+        if (holder)
+            holder->destroy();
+    }
+    TensorRTHolder(const TensorRTHolder&) = delete;
+    TensorRTHolder& operator=(const TensorRTHolder&) = delete;
+    TensorRTHolder(TensorRTHolder&& rhs) noexcept {
+        holder = rhs.holder;
+        rhs.holder = nullptr;
+    }
+    TensorRTHolder& operator=(TensorRTHolder&& rhs) noexcept {
+        if (this == &rhs) {
+            return *this;
+        }
+        if (holder) holder->destroy();
+        holder = rhs.holder;
+        rhs.holder = nullptr;
+        return *this;
+    }
+    T* operator->() {
+        return holder;
+    }
+    T* get() { return holder; }
+    explicit operator bool() { return holder != nullptr; }
+    T& operator*() noexcept { return *holder; }
+};
+
+template <typename T>
+TensorRTHolder<T> make_holder(T* holder) {
+    return TensorRTHolder<T>(holder);
+}
+
+template <typename T>
+using TensorRTNonHolder = T*; // raw (non-owning) alias, kept for readability
\ No newline at end of file
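A usage sketch of the holder (hypothetical snippet; TensorRT 7-era `destroy()` semantics assumed):

```cpp
#include "NvInfer.h"
#include "holder.h"
#include "logging.h"

static Logger demoLogger; // stand-in logger for the sketch

void holder_demo() {
    // make_holder() ties IRuntime::destroy() to scope exit, so early
    // returns cannot leak the runtime.
    auto runtime = make_holder(nvinfer1::createInferRuntime(demoLogger));
    if (!runtime) return;
    nvinfer1::IRuntime* raw = runtime.get(); // borrowed, still owned by the holder
    (void)raw;
} // runtime->destroy() runs here
```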
diff --git a/ibnnet/ibnnet.cpp b/ibnnet/ibnnet.cpp
new file mode 100644
index 00000000..b0a2a766
--- /dev/null
+++ b/ibnnet/ibnnet.cpp
@@ -0,0 +1,197 @@
+#include "ibnnet.h"
+
+//#define USE_FP16
+
+namespace trt {
+
+    IBNNet::IBNNet(trt::EngineConfig &enginecfg, const IBN ibn) : _engineCfg(enginecfg) {
+        switch(ibn) {
+            case IBN::A:
+                _ibn = "a";
+                break;
+            case IBN::B:
+                _ibn = "b";
+                break;
+            case IBN::NONE:
+            default:
+                _ibn = "";
+                break;
+        }
+    }
+
+    // Create the engine using only the API and not any parser.
+    ICudaEngine *IBNNet::createEngine(IBuilder* builder, IBuilderConfig* config) {
+        // resnet50-ibna, resnet50-ibnb, resnet50
+        assert(_ibn == "a" || _ibn == "b" || _ibn == "");
+        INetworkDefinition* network = builder->createNetworkV2(0U);
+
+        // Create input tensor of shape { 3, INPUT_H, INPUT_W } with name INPUT_BLOB_NAME
+        ITensor* data = network->addInput(_engineCfg.input_name, _dt, Dims3{3, _engineCfg.input_h, _engineCfg.input_w});
+        assert(data);
+
+        std::string path;
+        if(_ibn == "") {
+            path = "../resnet50.wts";
+        } else {
+            path = "../resnet50-ibn" + _ibn + ".wts";
+        }
+
+        std::map<std::string, Weights> weightMap = loadWeights(path);
+        Weights emptywts{DataType::kFLOAT, nullptr, 0};
+
+        // Normalization variant per bottleneck block (16 blocks in ResNet-50):
+        // "a" = split IN/BN, "b" = IN after the residual add, "" = plain BN.
+        std::map<std::string, std::vector<std::string>> ibn_layers{
+            { "a", {"a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "", "", ""}},
+            { "b", {"", "", "b", "", "", "", "b", "", "", "", "", "", "", "", "", ""}},
+            { "", std::vector<std::string>(16, "")}};
+
+        const float mean[3] = {0.485, 0.456, 0.406}; // rgb
+        const float std[3] = {0.229, 0.224, 0.225};
+        ITensor* pre_input = MeanStd(network, weightMap, data, "", mean, std, false);
+
+        IConvolutionLayer* conv1 = network->addConvolutionNd(*pre_input, 64, DimsHW{7, 7}, weightMap["conv1.weight"], emptywts);
+        assert(conv1);
+        conv1->setStrideNd(DimsHW{2, 2});
+        conv1->setPaddingNd(DimsHW{3, 3});
+
+        IActivationLayer* relu1{nullptr};
+        if (_ibn == "b") {
+            IScaleLayer* bn1 = addInstanceNorm2d(network, weightMap, *conv1->getOutput(0), "bn1", 1e-5);
+            relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
+        } else {
+            IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "bn1", 1e-5);
+            relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
+        }
+        assert(relu1);
+
+        // Max pooling with a 3x3 kernel, stride 2 and padding 1.
+        IPoolingLayer* pool1 = network->addPoolingNd(*relu1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
+        assert(pool1);
+        pool1->setStrideNd(DimsHW{2, 2});
+        pool1->setPaddingNd(DimsHW{1, 1});
+
+        IActivationLayer* x = bottleneck_ibn(network, weightMap, *pool1->getOutput(0), 64, 64, 1, "layer1.0.", ibn_layers[_ibn][0]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 256, 64, 1, "layer1.1.", ibn_layers[_ibn][1]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 256, 64, 1, "layer1.2.", ibn_layers[_ibn][2]);
+
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 256, 128, 2, "layer2.0.", ibn_layers[_ibn][3]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "layer2.1.", ibn_layers[_ibn][4]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "layer2.2.", ibn_layers[_ibn][5]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "layer2.3.", ibn_layers[_ibn][6]);
+
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 256, 2, "layer3.0.", ibn_layers[_ibn][7]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "layer3.1.", ibn_layers[_ibn][8]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "layer3.2.", ibn_layers[_ibn][9]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "layer3.3.", ibn_layers[_ibn][10]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "layer3.4.", ibn_layers[_ibn][11]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "layer3.5.", ibn_layers[_ibn][12]);
+
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 512, 2, "layer4.0.", ibn_layers[_ibn][13]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "layer4.1.", ibn_layers[_ibn][14]);
+        x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "layer4.2.", ibn_layers[_ibn][15]);
+
+        IPoolingLayer* pool2 = network->addPoolingNd(*x->getOutput(0), PoolingType::kAVERAGE, DimsHW{7, 7});
+        assert(pool2);
+        pool2->setStrideNd(DimsHW{1, 1});
+
+        IFullyConnectedLayer* fc1 = network->addFullyConnected(*pool2->getOutput(0), 1000, weightMap["fc.weight"], weightMap["fc.bias"]);
+        assert(fc1);
+
+        fc1->getOutput(0)->setName(_engineCfg.output_name);
+        std::cout << "set name out" << std::endl;
+        network->markOutput(*fc1->getOutput(0));
+
+        // Build engine
+        builder->setMaxBatchSize(_engineCfg.max_batch_size);
+        config->setMaxWorkspaceSize(1 << 20);
+
+#ifdef USE_FP16
+        config->setFlag(BuilderFlag::kFP16);
+#endif
+        ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
+        std::cout << "build out" << std::endl;
+
+        // The network definition is no longer needed.
+        network->destroy();
+
+        // Release host memory
+        for (auto& mem : weightMap) {
+            free((void*) (mem.second.values));
+        }
+
+        return engine;
+    }
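Beyond the compile-time `USE_FP16` switch, the FP16 flag can be gated on hardware support at run time; a hedged variant (standard TensorRT 7 builder API) that would sit inside `createEngine()` just before `buildEngineWithConfig()`:

```cpp
// Request half-precision kernels only when the GPU supports them natively.
if (builder->platformHasFastFp16()) {
    config->setFlag(BuilderFlag::kFP16);
}
```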
+
+    bool IBNNet::serializeEngine() {
+        // Create builder
+        auto builder = make_holder(createInferBuilder(gLogger));
+        auto config = make_holder(builder->createBuilderConfig());
+        // Create model to populate the network, then set the outputs and create an engine
+        ICudaEngine *engine = createEngine(builder.get(), config.get());
+        assert(engine);
+
+        // Serialize the engine
+        TensorRTHolder<IHostMemory> modelStream = make_holder(engine->serialize());
+        assert(modelStream);
+
+        std::ofstream p("./ibnnet.engine", std::ios::binary | std::ios::out);
+        if (!p) {
+            std::cerr << "could not open plan output file" << std::endl;
+            return false;
+        }
+        p.write(reinterpret_cast<const char*>(modelStream->data()), modelStream->size());
+
+        return true;
+    }
+
+    bool IBNNet::deserializeEngine() {
+        std::ifstream file("./ibnnet.engine", std::ios::binary | std::ios::in);
+        if (file.good()) {
+            file.seekg(0, file.end);
+            _engineCfg.stream_size = file.tellg();
+            file.seekg(0, file.beg);
+            _engineCfg.trtModelStream = std::shared_ptr<char>( new char[_engineCfg.stream_size], []( char* ptr ){ delete [] ptr; } );
+            assert(_engineCfg.trtModelStream.get());
+            file.read(_engineCfg.trtModelStream.get(), _engineCfg.stream_size);
+            file.close();
+
+            _inferEngine = make_unique<InferenceEngine>(_engineCfg);
+            return true;
+        }
+        return false;
+    }
+
+    // HWC BGR uint8 -> CHW RGB float in [0,1]; one image occupies 3 * stride floats.
+    void IBNNet::preprocessing(const cv::Mat& img, float* const data, const std::size_t stride) {
+        for (std::size_t i = 0; i < stride; ++i) {
+            data[i] = img.at<cv::Vec3b>(i)[2] / 255.0;
+            data[i + stride] = img.at<cv::Vec3b>(i)[1] / 255.0;
+            data[i + (stride << 1)] = img.at<cv::Vec3b>(i)[0] / 255.0;
+        }
+    }
+
+    bool IBNNet::inference(std::vector<cv::Mat> &input) {
+        if(_inferEngine != nullptr) {
+            const std::size_t stride = _engineCfg.input_w * _engineCfg.input_h;
+            return _inferEngine->doInference(static_cast<int>(input.size()),
+                [&](float* data) {
+                    for(const auto &img : input) {
+                        preprocessing(img, data, stride);
+                        data += 3 * stride;
+                    }
+                }
+            );
+        } else {
+            return false;
+        }
+    }
+
+    float* IBNNet::getOutput() {
+        if(_inferEngine != nullptr)
+            return _inferEngine->getOutput();
+        return nullptr;
+    }
+
+    int IBNNet::getDeviceID() {
+        return _engineCfg.device_id;
+    }
+
+}
\ No newline at end of file
diff --git a/ibnnet/ibnnet.h b/ibnnet/ibnnet.h
new file mode 100644
index 00000000..c75537d4
--- /dev/null
+++ b/ibnnet/ibnnet.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include "utils.h"
+#include "holder.h"
+#include "layers.h"
+#include "InferenceEngine.h"
+#include <opencv2/opencv.hpp>
+#include <string>
+#include <vector>
+#include <memory>
+extern Logger gLogger;
+using namespace trtxapi;
+
+namespace trt {
+
+    enum IBN {
+        A,    // resnet50-ibna
+        B,    // resnet50-ibnb
+        NONE  // resnet50
+    };
+
+    class IBNNet {
+    public:
+        IBNNet(trt::EngineConfig &enginecfg, const IBN ibn);
+        ~IBNNet() {}
+
+        bool serializeEngine();                       /* create & serialize engine */
+        bool deserializeEngine();
+        bool inference(std::vector<cv::Mat> &input);  /* supports batch inference */
+
+        float* getOutput();
+        int getDeviceID();                            /* cuda device id */
+
+    private:
+        ICudaEngine *createEngine(IBuilder *builder, IBuilderConfig *config);
+        void preprocessing(const cv::Mat& img, float* const data, const std::size_t stride);
+
+    private:
+        trt::EngineConfig _engineCfg;
+        std::unique_ptr<InferenceEngine> _inferEngine{nullptr};
+        std::string _ibn;
+        DataType _dt{DataType::kFLOAT};
+    };
+
+}
\ No newline at end of file
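A hypothetical end-to-end call for a real image (the file path and the 224x224 input size are assumptions; `cv::imread`/`cv::resize` are standard OpenCV):

```cpp
#include <opencv2/opencv.hpp>
#include <string>
#include <vector>
#include "ibnnet.h"

bool classify_one(trt::IBNNet& model, const std::string& path) {
    cv::Mat img = cv::imread(path);            // 8-bit BGR, as preprocessing() expects
    if (img.empty()) return false;
    cv::resize(img, img, cv::Size(224, 224));  // must match input_w x input_h
    std::vector<cv::Mat> batch{img};
    return model.inference(batch);             // scores via model.getOutput()
}
```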
diff --git a/ibnnet/layers.cpp b/ibnnet/layers.cpp
new file mode 100644
index 00000000..0076cf9e
--- /dev/null
+++ b/ibnnet/layers.cpp
@@ -0,0 +1,210 @@
+#include "layers.h"
+
+namespace trtxapi {
+
+    ITensor* MeanStd(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor* input, const std::string lname, const float* mean, const float* std, const bool div255) {
+        if(div255) {
+            Weights Div_255{ DataType::kFLOAT, nullptr, 3 };
+            float *wgt = reinterpret_cast<float*>(malloc(sizeof(float) * 3));
+            std::fill_n(wgt, 3, 255.0f);
+            Div_255.values = wgt;
+            weightMap[lname + ".div"] = Div_255;
+            IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_255);
+            input = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV)->getOutput(0);
+        }
+        Weights Mean{ DataType::kFLOAT, nullptr, 3 };
+        Mean.values = mean;
+        IConstantLayer* m = network->addConstant(Dims3{ 3, 1, 1 }, Mean);
+        IElementWiseLayer* sub_mean = network->addElementWise(*input, *m->getOutput(0), ElementWiseOperation::kSUB);
+        if (std != nullptr) {
+            Weights Std{ DataType::kFLOAT, nullptr, 3 };
+            Std.values = std;
+            IConstantLayer* s = network->addConstant(Dims3{ 3, 1, 1 }, Std);
+            IElementWiseLayer* std_mean = network->addElementWise(*sub_mean->getOutput(0), *s->getOutput(0), ElementWiseOperation::kDIV);
+            return std_mean->getOutput(0);
+        } else {
+            return sub_mean->getOutput(0);
+        }
+    }
+
+    IScaleLayer* addBatchNorm2d(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const std::string lname, const float eps) {
+        float *gamma = (float*)weightMap[lname + ".weight"].values;
+        float *beta = (float*)weightMap[lname + ".bias"].values;
+        float *mean = (float*)weightMap[lname + ".running_mean"].values;
+        float *var = (float*)weightMap[lname + ".running_var"].values;
+        int len = weightMap[lname + ".running_var"].count;
+
+        // Fold BN into a per-channel scale/shift: y = gamma * (x - mean) / sqrt(var + eps) + beta
+        float *scval = reinterpret_cast<float*>(malloc(sizeof(float) * len));
+        for (int i = 0; i < len; i++) {
+            scval[i] = gamma[i] / sqrt(var[i] + eps);
+        }
+        Weights wscale{DataType::kFLOAT, scval, len};
+
+        float *shval = reinterpret_cast<float*>(malloc(sizeof(float) * len));
+        for (int i = 0; i < len; i++) {
+            shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps);
+        }
+        Weights wshift{DataType::kFLOAT, shval, len};
+
+        float *pval = reinterpret_cast<float*>(malloc(sizeof(float) * len));
+        for (int i = 0; i < len; i++) {
+            pval[i] = 1.0;
+        }
+        Weights wpower{DataType::kFLOAT, pval, len};
+
+        // Store in the weight map so the buffers stay alive until the engine is built.
+        weightMap[lname + ".scale"] = wscale;
+        weightMap[lname + ".shift"] = wshift;
+        weightMap[lname + ".power"] = wpower;
+        IScaleLayer* scale_1 = network->addScale(input, ScaleMode::kCHANNEL, wshift, wscale, wpower);
+        assert(scale_1);
+        return scale_1;
+    }
+
+    IScaleLayer* addInstanceNorm2d(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const std::string lname, const float eps) {
+
+        int len = weightMap[lname + ".weight"].count;
+
+        // Reduce over H and W (axis mask 6 = bits 1 and 2 under implicit batch) to get the per-channel mean.
+        IReduceLayer* reduce1 = network->addReduce(input,
+            ReduceOperation::kAVG,
+            6,
+            true);
+        assert(reduce1);
+
+        IElementWiseLayer* ew1 = network->addElementWise(input,
+            *reduce1->getOutput(0),
+            ElementWiseOperation::kSUB);
+        assert(ew1);
+
+        // (x - mean)^2 via a uniform scale layer with power = 2
+        const static float pval1[3]{0.0, 1.0, 2.0};
+        Weights wshift1{DataType::kFLOAT, pval1, 1};
+        Weights wscale1{DataType::kFLOAT, pval1 + 1, 1};
+        Weights wpower1{DataType::kFLOAT, pval1 + 2, 1};
+
+        IScaleLayer* scale1 = network->addScale(
+            *ew1->getOutput(0),
+            ScaleMode::kUNIFORM,
+            wshift1,
+            wscale1,
+            wpower1);
+        assert(scale1);
+
+        IReduceLayer* reduce2 = network->addReduce(
+            *scale1->getOutput(0),
+            ReduceOperation::kAVG,
+            6,
+            true);
+        assert(reduce2);
+
+        // sqrt(var + eps) via shift = eps, power = 0.5. Note: the static is initialized
+        // on the first call only, so this assumes eps is identical across calls (it is 1e-5 throughout).
+        const static float pval2[3]{eps, 1.0, 0.5};
+        Weights wshift2{DataType::kFLOAT, pval2, 1};
+        Weights wscale2{DataType::kFLOAT, pval2 + 1, 1};
+        Weights wpower2{DataType::kFLOAT, pval2 + 2, 1};
+
+        IScaleLayer* scale2 = network->addScale(
+            *reduce2->getOutput(0),
+            ScaleMode::kUNIFORM,
+            wshift2,
+            wscale2,
+            wpower2);
+        assert(scale2);
+
+        IElementWiseLayer* ew2 = network->addElementWise(*ew1->getOutput(0),
+            *scale2->getOutput(0),
+            ElementWiseOperation::kDIV);
+        assert(ew2);
+
+        float* pval3 = reinterpret_cast<float*>(malloc(sizeof(float) * len));
+        std::fill_n(pval3, len, 1.0f);
+        Weights wpower3{DataType::kFLOAT, pval3, len};
+        weightMap[lname + ".power3"] = wpower3;
+
+        // Apply the learned per-channel affine: weight * x_hat + bias
+        IScaleLayer* scale3 = network->addScale(
+            *ew2->getOutput(0),
+            ScaleMode::kCHANNEL,
+            weightMap[lname + ".bias"],
+            weightMap[lname + ".weight"],
+            wpower3);
+        assert(scale3);
+        return scale3;
+    }
+
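For reference, the reduce/scale/elementwise chain above computes the standard instance-norm transform, with mean and variance taken per channel over the spatial dimensions (implicit batch, so there is no batch index):

```latex
\mathrm{IN}(x_{chw}) = \gamma_c \, \frac{x_{chw} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c,
\qquad
\mu_c = \frac{1}{HW} \sum_{h,w} x_{chw},
\qquad
\sigma_c^2 = \frac{1}{HW} \sum_{h,w} \left(x_{chw} - \mu_c\right)^2
```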
+    IConcatenationLayer* addIBN(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const std::string lname) {
+        // IBN-a: first half of the channels goes through InstanceNorm, the second half through BatchNorm.
+        Dims splitDims = input.getDimensions();
+        ISliceLayer *split1 = network->addSlice(input,
+            Dims3{0, 0, 0},
+            Dims3{splitDims.d[0] / 2, splitDims.d[1], splitDims.d[2]},
+            Dims3{1, 1, 1});
+        assert(split1);
+
+        ISliceLayer *split2 = network->addSlice(input,
+            Dims3{splitDims.d[0] / 2, 0, 0},
+            Dims3{splitDims.d[0] / 2, splitDims.d[1], splitDims.d[2]},
+            Dims3{1, 1, 1});
+        assert(split2);
+
+        auto in1 = addInstanceNorm2d(network, weightMap, *split1->getOutput(0), lname + "IN", 1e-5);
+        auto bn1 = addBatchNorm2d(network, weightMap, *split2->getOutput(0), lname + "BN", 1e-5);
+
+        ITensor* tensor1[] = {in1->getOutput(0), bn1->getOutput(0)};
+        auto cat1 = network->addConcatenation(tensor1, 2);
+        assert(cat1);
+        return cat1;
+    }
+
+    IActivationLayer* bottleneck_ibn(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const int inch, const int outch, const int stride, const std::string lname, const std::string ibn) {
+        Weights emptywts{DataType::kFLOAT, nullptr, 0};
+
+        IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{1, 1}, weightMap[lname + "conv1.weight"], emptywts);
+        assert(conv1);
+
+        IActivationLayer* relu1{nullptr};
+        if (ibn == "a") {
+            IConcatenationLayer* bn1 = addIBN(network, weightMap, *conv1->getOutput(0), lname + "bn1.");
+            relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
+            assert(relu1);
+        } else {
+            IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "bn1", 1e-5);
+            relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
+            assert(relu1);
+        }
+
+        IConvolutionLayer* conv2 = network->addConvolutionNd(*relu1->getOutput(0), outch, DimsHW{3, 3}, weightMap[lname + "conv2.weight"], emptywts);
+        assert(conv2);
+        conv2->setStrideNd(DimsHW{stride, stride});
+        conv2->setPaddingNd(DimsHW{1, 1});
+
+        IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv2->getOutput(0), lname + "bn2", 1e-5);
+
+        IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU);
+        assert(relu2);
+
+        IConvolutionLayer* conv3 = network->addConvolutionNd(*relu2->getOutput(0), outch * 4, DimsHW{1, 1}, weightMap[lname + "conv3.weight"], emptywts);
+        assert(conv3);
+
+        IScaleLayer* bn3 = addBatchNorm2d(network, weightMap, *conv3->getOutput(0), lname + "bn3", 1e-5);
+
+        // Projection shortcut when the spatial size or channel count changes; identity otherwise.
+        IElementWiseLayer* ew1;
+        if (stride != 1 || inch != outch * 4) {
+            IConvolutionLayer* conv4 = network->addConvolutionNd(input, outch * 4, DimsHW{1, 1}, weightMap[lname + "downsample.0.weight"], emptywts);
+            assert(conv4);
+            conv4->setStrideNd(DimsHW{stride, stride});
+
+            IScaleLayer* bn4 = addBatchNorm2d(network, weightMap, *conv4->getOutput(0), lname + "downsample.1", 1e-5);
+            ew1 = network->addElementWise(*bn4->getOutput(0), *bn3->getOutput(0), ElementWiseOperation::kSUM);
+        } else {
+            ew1 = network->addElementWise(input, *bn3->getOutput(0), ElementWiseOperation::kSUM);
+        }
+
+        IActivationLayer* relu3{nullptr};
+        if (ibn == "b") {
+            // IBN-b: InstanceNorm applied after the residual addition.
+            IScaleLayer* in1 = addInstanceNorm2d(network, weightMap, *ew1->getOutput(0), lname + "IN", 1e-5);
+            relu3 = network->addActivation(*in1->getOutput(0), ActivationType::kRELU);
+        } else {
+            relu3 = network->addActivation(*ew1->getOutput(0), ActivationType::kRELU);
+        }
+
+        assert(relu3);
+        return relu3;
+    }
+
+}
\ No newline at end of file
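The design choice in `addIBN` follows the IBN-a block from the paper: the instance-normalized half captures style-invariant features while the batch-normalized half preserves content discrimination. A sketch of the slice bookkeeping (illustrative values only, for layer1 of ResNet-50 where the block's conv1 outputs 64 channels at 56x56):

```cpp
// split1: start {0, 0, 0},  size {32, 56, 56} -> addInstanceNorm2d
// split2: start {32, 0, 0}, size {32, 56, 56} -> addBatchNorm2d
// concat along C restores {64, 56, 56}.
// The integer halving assumes an even channel count, which holds at
// every IBN-a insertion point in ResNet-50.
```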
diff --git a/ibnnet/layers.h b/ibnnet/layers.h
new file mode 100644
index 00000000..e37f88b5
--- /dev/null
+++ b/ibnnet/layers.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <map>
+#include <string>
+#include <cmath>
+#include "NvInfer.h"
+#include "cuda_runtime_api.h"
+using namespace nvinfer1;
+
+namespace trtxapi {
+
+    ITensor* MeanStd(INetworkDefinition *network,
+                     std::map<std::string, Weights>& weightMap,
+                     ITensor* input,
+                     const std::string lname,
+                     const float* mean,
+                     const float* std,
+                     const bool div255);
+
+    IScaleLayer* addBatchNorm2d(INetworkDefinition *network,
+                                std::map<std::string, Weights>& weightMap,
+                                ITensor& input,
+                                const std::string lname,
+                                const float eps);
+
+    IScaleLayer* addInstanceNorm2d(INetworkDefinition *network,
+                                   std::map<std::string, Weights>& weightMap,
+                                   ITensor& input,
+                                   const std::string lname,
+                                   const float eps);
+
+    IConcatenationLayer* addIBN(INetworkDefinition *network,
+                                std::map<std::string, Weights>& weightMap,
+                                ITensor& input,
+                                const std::string lname);
+
+    IActivationLayer* bottleneck_ibn(INetworkDefinition *network,
+                                     std::map<std::string, Weights>& weightMap,
+                                     ITensor& input,
+                                     const int inch,
+                                     const int outch,
+                                     const int stride,
+                                     const std::string lname,
+                                     const std::string ibn);
+
+}
\ No newline at end of file
diff --git a/ibnnet/logging.h b/ibnnet/logging.h
new file mode 100644
index 00000000..602b69fb
--- /dev/null
+++ b/ibnnet/logging.h
@@ -0,0 +1,503 @@
+/*
+ * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TENSORRT_LOGGING_H
+#define TENSORRT_LOGGING_H
+
+#include "NvInferRuntimeCommon.h"
+#include <cassert>
+#include <cstdlib>
+#include <ctime>
+#include <iomanip>
+#include <iostream>
+#include <sstream>
+#include <string>
+
+using Severity = nvinfer1::ILogger::Severity;
+
+class LogStreamConsumerBuffer : public std::stringbuf
+{
+public:
+    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
+        : mOutput(stream)
+        , mPrefix(prefix)
+        , mShouldLog(shouldLog)
+    {
+    }
+
+    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
+        : mOutput(other.mOutput)
+    {
+    }
+
+    ~LogStreamConsumerBuffer()
+    {
+        // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
+        // std::streambuf::pptr() gives a pointer to the current position of the output sequence
+        // if the pointer to the beginning is not equal to the pointer to the current position,
+        // call putOutput() to log the output to the stream
+        if (pbase() != pptr())
+        {
+            putOutput();
+        }
+    }
+
+    // synchronizes the stream buffer and returns 0 on success
+    // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
+    // resetting the buffer and flushing the stream
+    virtual int sync()
+    {
+        putOutput();
+        return 0;
+    }
+
+    void putOutput()
+    {
+        if (mShouldLog)
+        {
+            // prepend timestamp
+            std::time_t timestamp = std::time(nullptr);
+            tm* tm_local = std::localtime(&timestamp);
+            std::cout << "[";
+            std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
+            std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
+            // std::stringbuf::str() gets the string contents of the buffer
+            // insert the buffer contents pre-appended by the appropriate prefix into the stream
+            mOutput << mPrefix << str();
+            // set the buffer to empty
+            str("");
+            // flush the stream
+            mOutput.flush();
+        }
+    }
+
+    void setShouldLog(bool shouldLog)
+    {
+        mShouldLog = shouldLog;
+    }
+
+private:
+    std::ostream& mOutput;
+    std::string mPrefix;
+    bool mShouldLog;
+};
+
+//!
+//! \class LogStreamConsumerBase
+//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
+//!
+class LogStreamConsumerBase
+{
+public:
+    LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
+        : mBuffer(stream, prefix, shouldLog)
+    {
+    }
+
+protected:
+    LogStreamConsumerBuffer mBuffer;
+};
+
+//!
+//! \class LogStreamConsumer
+//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
+//!  Order of base classes is LogStreamConsumerBase and then std::ostream.
+//!  This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
+//!  in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
+//!  This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
+//!  Please do not change the order of the parent classes.
+//!
+class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
+{
+public:
+    //! \brief Creates a LogStreamConsumer which logs messages with level severity.
+    //!
Reportable severity determines if the messages are severe enough to be logged. + LogStreamConsumer(Severity reportableSeverity, Severity severity) + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) + , std::ostream(&mBuffer) // links the stream buffer with the stream + , mShouldLog(severity <= reportableSeverity) + , mSeverity(severity) + { + } + + LogStreamConsumer(LogStreamConsumer&& other) + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) + , std::ostream(&mBuffer) // links the stream buffer with the stream + , mShouldLog(other.mShouldLog) + , mSeverity(other.mSeverity) + { + } + + void setReportableSeverity(Severity reportableSeverity) + { + mShouldLog = mSeverity <= reportableSeverity; + mBuffer.setShouldLog(mShouldLog); + } + +private: + static std::ostream& severityOstream(Severity severity) + { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + static std::string severityPrefix(Severity severity) + { + switch (severity) + { + case Severity::kINTERNAL_ERROR: return "[F] "; + case Severity::kERROR: return "[E] "; + case Severity::kWARNING: return "[W] "; + case Severity::kINFO: return "[I] "; + case Severity::kVERBOSE: return "[V] "; + default: assert(0); return ""; + } + } + + bool mShouldLog; + Severity mSeverity; +}; + +//! \class Logger +//! +//! \brief Class which manages logging of TensorRT tools and samples +//! +//! \details This class provides a common interface for TensorRT tools and samples to log information to the console, +//! and supports logging two types of messages: +//! +//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) +//! - Test pass/fail messages +//! +//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is +//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. +//! +//! In the future, this class could be extended to support dumping test results to a file in some standard format +//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). +//! +//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger +//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT +//! library and messages coming from the sample. +//! +//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the +//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger +//! object. + +class Logger : public nvinfer1::ILogger +{ +public: + Logger(Severity severity = Severity::kWARNING) + : mReportableSeverity(severity) + { + } + + //! + //! \enum TestResult + //! \brief Represents the state of a given test + //! + enum class TestResult + { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived + }; + + //! + //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger + //! \return The nvinfer1::ILogger associated with this Logger + //! + //! TODO Once all samples are updated to use this method to register the logger with TensorRT, + //! we can eliminate the inheritance of Logger from ILogger + //! 
+ nvinfer1::ILogger& getTRTLogger() + { + return *this; + } + + //! + //! \brief Implementation of the nvinfer1::ILogger::log() virtual method + //! + //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the + //! inheritance from nvinfer1::ILogger + //! + void log(Severity severity, const char* msg) override + { + LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; + } + + //! + //! \brief Method for controlling the verbosity of logging output + //! + //! \param severity The logger will only emit messages that have severity of this level or higher. + //! + void setReportableSeverity(Severity severity) + { + mReportableSeverity = severity; + } + + //! + //! \brief Opaque handle that holds logging information for a particular test + //! + //! This object is an opaque handle to information used by the Logger to print test results. + //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used + //! with Logger::reportTest{Start,End}(). + //! + class TestAtom + { + public: + TestAtom(TestAtom&&) = default; + + private: + friend class Logger; + + TestAtom(bool started, const std::string& name, const std::string& cmdline) + : mStarted(started) + , mName(name) + , mCmdline(cmdline) + { + } + + bool mStarted; + std::string mName; + std::string mCmdline; + }; + + //! + //! \brief Define a test for logging + //! + //! \param[in] name The name of the test. This should be a string starting with + //! "TensorRT" and containing dot-separated strings containing + //! the characters [A-Za-z0-9_]. + //! For example, "TensorRT.sample_googlenet" + //! \param[in] cmdline The command line used to reproduce the test + // + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + //! + static TestAtom defineTest(const std::string& name, const std::string& cmdline) + { + return TestAtom(false, name, cmdline); + } + + //! + //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments + //! as input + //! + //! \param[in] name The name of the test + //! \param[in] argc The number of command-line arguments + //! \param[in] argv The array of command-line arguments (given as C strings) + //! + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) + { + auto cmdline = genCmdlineString(argc, argv); + return defineTest(name, cmdline); + } + + //! + //! \brief Report that a test has started. + //! + //! \pre reportTestStart() has not been called yet for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has started + //! + static void reportTestStart(TestAtom& testAtom) + { + reportTestResult(testAtom, TestResult::kRUNNING); + assert(!testAtom.mStarted); + testAtom.mStarted = true; + } + + //! + //! \brief Report that a test has ended. + //! + //! \pre reportTestStart() has been called for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has ended + //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, + //! TestResult::kFAILED, TestResult::kWAIVED + //! 
+ static void reportTestEnd(const TestAtom& testAtom, TestResult result) + { + assert(result != TestResult::kRUNNING); + assert(testAtom.mStarted); + reportTestResult(testAtom, result); + } + + static int reportPass(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kPASSED); + return EXIT_SUCCESS; + } + + static int reportFail(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kFAILED); + return EXIT_FAILURE; + } + + static int reportWaive(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kWAIVED); + return EXIT_SUCCESS; + } + + static int reportTest(const TestAtom& testAtom, bool pass) + { + return pass ? reportPass(testAtom) : reportFail(testAtom); + } + + Severity getReportableSeverity() const + { + return mReportableSeverity; + } + +private: + //! + //! \brief returns an appropriate string for prefixing a log message with the given severity + //! + static const char* severityPrefix(Severity severity) + { + switch (severity) + { + case Severity::kINTERNAL_ERROR: return "[F] "; + case Severity::kERROR: return "[E] "; + case Severity::kWARNING: return "[W] "; + case Severity::kINFO: return "[I] "; + case Severity::kVERBOSE: return "[V] "; + default: assert(0); return ""; + } + } + + //! + //! \brief returns an appropriate string for prefixing a test result message with the given result + //! + static const char* testResultString(TestResult result) + { + switch (result) + { + case TestResult::kRUNNING: return "RUNNING"; + case TestResult::kPASSED: return "PASSED"; + case TestResult::kFAILED: return "FAILED"; + case TestResult::kWAIVED: return "WAIVED"; + default: assert(0); return ""; + } + } + + //! + //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity + //! + static std::ostream& severityOstream(Severity severity) + { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + //! + //! \brief method that implements logging test results + //! + static void reportTestResult(const TestAtom& testAtom, TestResult result) + { + severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " + << testAtom.mCmdline << std::endl; + } + + //! + //! \brief generate a command line string from the given (argc, argv) values + //! + static std::string genCmdlineString(int argc, char const* const* argv) + { + std::stringstream ss; + for (int i = 0; i < argc; i++) + { + if (i > 0) + ss << " "; + ss << argv[i]; + } + return ss.str(); + } + + Severity mReportableSeverity; +}; + +namespace +{ + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE +//! +//! Example usage: +//! +//! LOG_VERBOSE(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO +//! +//! Example usage: +//! +//! LOG_INFO(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_INFO(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING +//! +//! Example usage: +//! +//! LOG_WARN(logger) << "hello world" << std::endl; +//! 
+inline LogStreamConsumer LOG_WARN(const Logger& logger)
+{
+    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
+}
+
+//!
+//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
+//!
+//! Example usage:
+//!
+//!     LOG_ERROR(logger) << "hello world" << std::endl;
+//!
+inline LogStreamConsumer LOG_ERROR(const Logger& logger)
+{
+    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
+}
+
+//!
+//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
+//!        ("fatal" severity)
+//!
+//! Example usage:
+//!
+//!     LOG_FATAL(logger) << "hello world" << std::endl;
+//!
+inline LogStreamConsumer LOG_FATAL(const Logger& logger)
+{
+    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
+}
+
+} // anonymous namespace
+
+#endif // TENSORRT_LOGGING_H
diff --git a/ibnnet/main.cpp b/ibnnet/main.cpp
new file mode 100644
index 00000000..1b3e368c
--- /dev/null
+++ b/ibnnet/main.cpp
@@ -0,0 +1,98 @@
+#include <iostream>
+#include <thread>
+#include <vector>
+#include "ibnnet.h"
+#include "InferenceEngine.h"
+
+// stuff we know about the network and the input/output blobs
+static const int MAX_BATCH_SIZE = 4;
+static const int INPUT_H = 224;
+static const int INPUT_W = 224;
+static const int OUTPUT_SIZE = 1000;
+static const int DEVICE_ID = 0;
+const char* INPUT_BLOB_NAME = "data";
+const char* OUTPUT_BLOB_NAME = "prob";
+extern Logger gLogger;
+
+void run_infer(std::shared_ptr<trt::IBNNet> model) {
+
+    CHECK(cudaSetDevice(model->getDeviceID()));
+
+    if(!model->deserializeEngine()) {
+        std::cout << "DeserializeEngine Failed." << std::endl;
+        return;
+    }
+
+    /* supports batched input */
+    std::vector<cv::Mat> input;
+    input.emplace_back( cv::Mat(INPUT_H, INPUT_W, CV_8UC3, cv::Scalar(255, 255, 255)) );
+
+    /* run inference */
+    model->inference(input);
+
+    /* fetch the output staged back from device memory */
+    float* prob = model->getOutput();
+
+    /* print output */
+    std::cout << "\nOutput from thread_id: " << std::this_thread::get_id() << std::endl;
+    if( prob != nullptr ) {
+        for (size_t batch_idx = 0; batch_idx < input.size(); ++batch_idx) {
+            for (int p = 0; p < OUTPUT_SIZE; ++p) {
+                std::cout << prob[batch_idx * OUTPUT_SIZE + p] << " ";
+                if ((p + 1) % 10 == 0) {
+                    std::cout << std::endl;
+                }
+            }
+        }
+    }
+}
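The raw `fc` outputs printed above are unnormalized logits; a top-1 label per image can be recovered with a small helper (hypothetical, not part of the diff):

```cpp
#include <algorithm>

// Index of the highest-scoring class among output_size logits.
int top1_class(const float* logits, int output_size) {
    return static_cast<int>(std::max_element(logits, logits + output_size) - logits);
}
// e.g. int cls = top1_class(prob + batch_idx * OUTPUT_SIZE, OUTPUT_SIZE);
```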
+
+int main(int argc, char** argv) {
+
+    trt::EngineConfig engineCfg {
+        INPUT_BLOB_NAME,
+        OUTPUT_BLOB_NAME,
+        nullptr,
+        MAX_BATCH_SIZE,
+        INPUT_H,
+        INPUT_W,
+        OUTPUT_SIZE,
+        0,
+        DEVICE_ID};
+
+    if (argc == 2 && std::string(argv[1]) == "-s") {
+        std::cout << "Serializing Engine" << std::endl;
+        trt::IBNNet ibnnet{engineCfg, trt::IBN::A};
+        ibnnet.serializeEngine();
+        return 0;
+    } else if (argc == 2 && std::string(argv[1]) == "-d") {
+
+        /*
+         * Supports multi-thread inference (mthreads > 1).
+         * Each thread holds its own CudaEngine;
+         * they can run on different cuda devices through the trt::EngineConfig setting.
+         */
+        int mthreads = 1;
+        std::vector<std::thread> workers;
+        std::vector<std::shared_ptr<trt::IBNNet>> models;
+
+        for(int i = 0; i < mthreads; ++i) {
+            models.emplace_back( std::make_shared<trt::IBNNet>(engineCfg, trt::IBN::A) ); // for IBN-b: trt::IBN::B
+        }
+
+        for(int i = 0; i < mthreads; ++i) {
+            workers.emplace_back( std::thread(run_infer, models[i]) );
+        }
+
+        for(auto & worker : workers) {
+            worker.join();
+        }
+
+        return 0;
+    } else {
+        std::cerr << "arguments not right!" << std::endl;
+        std::cerr << "./ibnnet -s   // serialize model to plan file" << std::endl;
+        std::cerr << "./ibnnet -d   // deserialize plan file and run inference" << std::endl;
+        return -1;
+    }
+}
diff --git a/ibnnet/utils.cpp b/ibnnet/utils.cpp
new file mode 100644
index 00000000..2ca8aa99
--- /dev/null
+++ b/ibnnet/utils.cpp
@@ -0,0 +1,39 @@
+#include "utils.h"
+
+// Load weights from files shared with TensorRT samples.
+// TensorRT weight files have a simple space delimited format:
+// [type] [size] <data x size in hex>
+std::map<std::string, Weights> loadWeights(const std::string file) {
+    std::cout << "Loading weights: " << file << std::endl;
+    std::map<std::string, Weights> weightMap;
+
+    // Open weights file
+    std::ifstream input(file);
+    assert(input.is_open() && "Unable to load weight file.");
+
+    // Read number of weight blobs
+    int32_t count;
+    input >> count;
+    assert(count > 0 && "Invalid weight map file.");
+
+    while (count--) {
+        Weights wt{DataType::kFLOAT, nullptr, 0};
+        uint32_t size;
+
+        // Read name and size of blob
+        std::string name;
+        input >> name >> std::dec >> size;
+        wt.type = DataType::kFLOAT;
+
+        // Load blob: each value is the hex bit pattern of a 32-bit float
+        uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(uint32_t) * size));
+        for (uint32_t x = 0, y = size; x < y; ++x) {
+            input >> std::hex >> val[x];
+        }
+        wt.values = val;
+        wt.count = size;
+        weightMap[name] = wt;
+    }
+
+    return weightMap;
+}
diff --git a/ibnnet/utils.h b/ibnnet/utils.h
new file mode 100644
index 00000000..fdd5be84
--- /dev/null
+++ b/ibnnet/utils.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <iostream>
+#include "NvInfer.h"
+#include "cuda_runtime_api.h"
+#include "assert.h"
+#include <fstream>
+#include <map>
+#include <string>
+#include <memory>
+
+using namespace nvinfer1;
+
+#define CHECK(status) \
+    do \
+    { \
+        auto ret = (status); \
+        if (ret != 0) \
+        { \
+            std::cout << "Cuda failure: " << ret << std::endl; \
+            abort(); \
+        } \
+    } while (0)
+
+// C++11 stand-in for std::make_unique (which arrived in C++14).
+template <typename T, typename... Args>
+std::unique_ptr<T> make_unique(Args&&... args) {
+    return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
+}
+
+std::map<std::string, Weights> loadWeights(const std::string file);