[TensorRT Sample Code] How to Save Engine and Load Engine from file
The following code is taken from the TensorRT sample "TensorRT-7.0.0.11/samples/trtexec/trtexec.cpp".
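The snippets below use the TrtUniquePtr smart-pointer alias from the samples' common headers. A minimal sketch of that helper, assuming the TensorRT 7 convention that library objects are released via destroy() rather than delete:

#include <memory>

// TensorRT 7 objects are released with destroy(), not delete, so the
// samples wrap them in a std::unique_ptr with a custom deleter.
template <typename T>
struct TrtDestroyer
{
    void operator()(T* t)
    {
        t->destroy();
    }
};

template <typename T>
using TrtUniquePtr = std::unique_ptr<T, TrtDestroyer<T>>;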
1 Save Engine to file
bool saveEngine(const ICudaEngine& engine, const std::string& fileName, std::ostream& err)
{
    std::ofstream engineFile(fileName, std::ios::binary);
    if (!engineFile)
    {
        err << "Cannot open engine file: " << fileName << std::endl;
        return false;
    }

    // Serialize the engine into a host-memory blob, then write it to disk
    TrtUniquePtr<IHostMemory> serializedEngine{engine.serialize()};
    if (serializedEngine == nullptr)
    {
        err << "Engine serialization failed" << std::endl;
        return false;
    }

    engineFile.write(static_cast<char*>(serializedEngine->data()), serializedEngine->size());
    return !engineFile.fail();
}
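A hypothetical call site, assuming an engine already built elsewhere (the file name "model.trt" is only an example):

#include <iostream>

// Hypothetical helper: persist a built engine so later runs can skip the build.
bool cacheEngine(const nvinfer1::ICudaEngine& engine)
{
    return saveEngine(engine, "model.trt", std::cerr);
}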
2 Load Engine from file
TrtUniquePtr<nvinfer1::ICudaEngine> getEngine(
    const ModelOptions& model, const BuildOptions& build, const SystemOptions& sys, std::ostream& err)
{
    TrtUniquePtr<nvinfer1::ICudaEngine> engine;
    if (build.load)
    {
        // Deserialize a previously saved engine (.trt file)
        engine.reset(loadEngine(build.engine, sys.DLACore, err));
    }
    else
    {
        // Build from a model file such as .onnx or a Caffe prototxt/caffemodel
        engine.reset(modelToEngine(model, build, sys, err));
    }
    if (!engine)
    {
        err << "Engine creation failed" << std::endl;
        return nullptr;
    }
    if (build.save && !saveEngine(*engine, build.engine, err))
    {
        err << "Saving engine to file failed" << std::endl;
        return nullptr;
    }
    return engine;
}
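A typical build-once-then-cache workflow on top of getEngine might look like the following sketch, assuming model and sys are populated elsewhere; fileExists() is a hypothetical helper and "model.trt" an example path:

#include <fstream>
#include <iostream>

// Hypothetical helper: check whether a cached engine file already exists.
static bool fileExists(const std::string& path)
{
    return std::ifstream(path, std::ios::binary).good();
}

TrtUniquePtr<nvinfer1::ICudaEngine> getOrBuildEngine(
    const ModelOptions& model, BuildOptions build, const SystemOptions& sys)
{
    build.engine = "model.trt";             // example cache path
    build.load = fileExists(build.engine);  // reuse the cached engine if present
    build.save = !build.load;               // otherwise build once and cache it
    return getEngine(model, build, sys, std::cerr);
}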
2.1 Load Engine from .trt file
ICudaEngine* loadEngine(const std::string& engine, int DLACore, std::ostream& err)
{
    std::ifstream engineFile(engine, std::ios::binary);
    if (!engineFile)
    {
        err << "Error opening engine file: " << engine << std::endl;
        return nullptr;
    }

    // Determine the file size, then read the whole serialized engine into memory
    engineFile.seekg(0, engineFile.end);
    long int fsize = engineFile.tellg();
    engineFile.seekg(0, engineFile.beg);

    std::vector<char> engineData(fsize);
    engineFile.read(engineData.data(), fsize);
    if (!engineFile)
    {
        err << "Error loading engine file: " << engine << std::endl;
        return nullptr;
    }

    TrtUniquePtr<IRuntime> runtime{createInferRuntime(gLogger.getTRTLogger())};
    if (DLACore != -1)
    {
        runtime->setDLACore(DLACore);
    }
    return runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);
}
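Once the engine is deserialized, inference needs an execution context. A minimal sketch (the file name and the device-buffer setup are illustrative; error handling trimmed):

// Load a cached engine and prepare it for inference.
TrtUniquePtr<nvinfer1::ICudaEngine> engine{loadEngine("model.trt", /*DLACore=*/-1, std::cerr)};
if (engine)
{
    TrtUniquePtr<nvinfer1::IExecutionContext> context{engine->createExecutionContext()};
    // Allocate device buffers for each binding, then run:
    // context->executeV2(bindings);
}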
2.2 Load Engine from model file (Caffe, ONNX)
ICudaEngine* modelToEngine(
    const ModelOptions& model, const BuildOptions& build, const SystemOptions& sys, std::ostream& err)
{
    TrtUniquePtr<IBuilder> builder{createInferBuilder(gLogger.getTRTLogger())};
    if (builder == nullptr)
    {
        err << "Builder creation failed" << std::endl;
        return nullptr;
    }

    // ONNX models require an explicit-batch network; the implicit-batch path is
    // only taken when a max batch size is given for a non-ONNX model.
    const bool isOnnxModel = model.baseModel.format == ModelFormat::kONNX;
    auto batchFlag = (build.maxBatch && !isOnnxModel) ? 0U : 1U
        << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);

    TrtUniquePtr<INetworkDefinition> network{builder->createNetworkV2(batchFlag)};
    if (!network)
    {
        err << "Network creation failed" << std::endl;
        return nullptr;
    }

    Parser parser = modelToNetwork(model, *network, err);
    if (!parser)
    {
        err << "Parsing model failed" << std::endl;
        return nullptr;
    }

    return networkToEngine(build, sys, *builder, *network, err);
}
Parser modelToNetwork(const ModelOptions& model, nvinfer1::INetworkDefinition& network, std::ostream& err)
{
    Parser parser;
    const std::string& modelName = model.baseModel.model;
    switch (model.baseModel.format)
    {
    case ModelFormat::kCAFFE:
    {
        using namespace nvcaffeparser1;
        parser.caffeParser.reset(createCaffeParser());
        CaffeBufferShutter bufferShutter;
        const auto blobNameToTensor = parser.caffeParser->parse(
            model.prototxt.c_str(), modelName.empty() ? nullptr : modelName.c_str(), network, DataType::kFLOAT);
        if (!blobNameToTensor)
        {
            err << "Failed to parse caffe model or prototxt, tensors blob not found" << std::endl;
            parser.caffeParser.reset();
            break;
        }
        // Caffe networks do not mark outputs themselves, so mark each requested blob
        for (const auto& s : model.outputs)
        {
            if (blobNameToTensor->find(s.c_str()) == nullptr)
            {
                err << "Could not find output blob " << s << std::endl;
                parser.caffeParser.reset();
                break;
            }
            network.markOutput(*blobNameToTensor->find(s.c_str()));
        }
        break;
    }
    ...
    }
    return parser;
}
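Only the Caffe branch is shown above; the other cases are elided. For reference, the ONNX branch of the same switch typically looks like this sketch (assuming the nvonnxparser API shipped with TensorRT 7):

case ModelFormat::kONNX:
{
    using namespace nvonnxparser;
    parser.onnxParser.reset(createParser(network, gLogger.getTRTLogger()));
    // parseFromFile reads the .onnx file and populates the network in place
    if (!parser.onnxParser->parseFromFile(
            model.baseModel.model.c_str(), static_cast<int>(gLogger.getReportableSeverity())))
    {
        err << "Failed to parse onnx file" << std::endl;
        parser.onnxParser.reset();
    }
    break;
}

Note that modelToNetwork returns the Parser object: the parser owns the parsed model's weights, so it must stay alive until the engine is built in networkToEngine.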
ICudaEngine* networkToEngine(const BuildOptions& build, const SystemOptions& sys, IBuilder& builder,
    INetworkDefinition& network, std::ostream& err)
{
    TrtUniquePtr<IBuilderConfig> config{builder.createBuilderConfig()};
    IOptimizationProfile* profile{nullptr};
    if (build.maxBatch)
    {
        builder.setMaxBatchSize(build.maxBatch);
    }
    ...
    bool hasDynamicShapes{false};
    for (unsigned int i = 0, n = network.getNbInputs(); i < n; i++)
    {
        // Set formats and data types of inputs
        auto input = network.getInput(i);
        if (!build.inputFormats.empty())
        {
            input->setType(build.inputFormats[i].first);
            input->setAllowedFormats(build.inputFormats[i].second);
        }
        else
        {
            switch (input->getType())
            {
            case DataType::kINT32:
            case DataType::kBOOL:
                // Leave these as is.
                break;
            case DataType::kFLOAT:
            case DataType::kINT8:
            case DataType::kHALF:
                // User did not specify a floating-point format. Default to kFLOAT.
                input->setType(DataType::kFLOAT);
                break;
            }
            input->setAllowedFormats(1U << static_cast<int>(TensorFormat::kLINEAR));
        }
        ...
    }
    for (unsigned int i = 0, n = network.getNbOutputs(); i < n; i++)
    {
        // Set formats and data types of outputs
        auto output = network.getOutput(i);
        if (!build.outputFormats.empty())
        {
            output->setType(build.outputFormats[i].first);
            output->setAllowedFormats(build.outputFormats[i].second);
        }
        else
        {
            output->setAllowedFormats(1U << static_cast<int>(TensorFormat::kLINEAR));
        }
    }

    // build.workspace is given in MiB; the shift converts it to bytes
    config->setMaxWorkspaceSize(static_cast<size_t>(build.workspace) << 20);

    if (build.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (build.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
    }

    auto isInt8 = [](const IOFormat& format) { return format.first == DataType::kINT8; };
    auto int8IO = std::count_if(build.inputFormats.begin(), build.inputFormats.end(), isInt8)
        + std::count_if(build.outputFormats.begin(), build.outputFormats.end(), isInt8);
    if ((build.int8 && build.calibration.empty()) || int8IO)
    {
        // Explicitly set int8 scales if no calibrator is provided and if I/O tensors use int8,
        // because auto calibration does not support this case.
        setTensorScales(network);
    }
    else if (build.int8)
    {
        config->setInt8Calibrator(new RndInt8Calibrator(1, build.calibration, network, err));
    }

    if (build.safe)
    {
        config->setEngineCapability(sys.DLACore != -1 ? EngineCapability::kSAFE_DLA : EngineCapability::kSAFE_GPU);
    }

    if (sys.DLACore != -1)
    {
        ...
        config->setDefaultDeviceType(DeviceType::kDLA);
        config->setDLACore(sys.DLACore);
        config->setFlag(BuilderFlag::kSTRICT_TYPES);
        if (sys.fallback)
        {
            // Allow layers unsupported on DLA to fall back to the GPU
            config->setFlag(BuilderFlag::kGPU_FALLBACK);
        }
        if (!build.int8)
        {
            config->setFlag(BuilderFlag::kFP16);
        }
        ...
    }

    return builder.buildEngineWithConfig(network, *config);
}