
[TensorRT Sample Code] How to Save an Engine to a File and Load It Back


The code below is taken from the TensorRT sample "TensorRT-7.0.0.11/samples/trtexec/trtexec.cpp".

1 Save Engine to file

bool saveEngine(const ICudaEngine& engine, const std::string& fileName, std::ostream& err)
{
    std::ofstream engineFile(fileName, std::ios::binary);
    if (!engineFile)
    {
        err << "Cannot open engine file: " << fileName << std::endl;
        return false;
    }

    // Serialize the engine into a host-memory buffer.
    TrtUniquePtr<IHostMemory> serializedEngine{engine.serialize()};
    if (serializedEngine == nullptr)
    {
        err << "Engine serialization failed" << std::endl;
        return false;
    }

    // Write the raw plan bytes to disk.
    engineFile.write(static_cast<char*>(serializedEngine->data()), serializedEngine->size());
    return !engineFile.fail();
}
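For readers who want the same round trip without the sample's TrtUniquePtr helper, here is a minimal standalone sketch using the plain TensorRT 7 API (the function name and path handling are illustrative, not part of the sample):

#include <NvInfer.h>
#include <fstream>
#include <string>

// Minimal sketch: serialize an already-built engine to disk with raw TRT 7 calls.
bool writePlanToDisk(const nvinfer1::ICudaEngine& engine, const std::string& path)
{
    nvinfer1::IHostMemory* plan = engine.serialize(); // owning raw pointer in TRT 7
    if (plan == nullptr)
    {
        return false;
    }
    std::ofstream out(path, std::ios::binary);
    out.write(static_cast<const char*>(plan->data()), plan->size());
    plan->destroy(); // TRT 7 objects are released via destroy(), not delete
    return !out.fail();
}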
    
    
    
    

2 Load Engine from file

TrtUniquePtr<nvinfer1::ICudaEngine> getEngine(const ModelOptions& model, const BuildOptions& build, const SystemOptions& sys, std::ostream& err)
{
    TrtUniquePtr<nvinfer1::ICudaEngine> engine;
    if (build.load)
    {
        // Deserialize a previously serialized plan (.trt file).
        engine.reset(loadEngine(build.engine, sys.DLACore, err));
    }
    else
    {
        // Build from a model file such as .onnx or .prototxt/.caffemodel.
        engine.reset(modelToEngine(model, build, sys, err));
    }
    if (!engine)
    {
        err << "Engine creation failed" << std::endl;
        return nullptr;
    }
    if (build.save && !saveEngine(*engine, build.engine, err))
    {
        err << "Saving engine to file failed" << std::endl;
        return nullptr;
    }
    return engine;
}
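A typical way to drive getEngine is to build on the first run and reuse the cached plan afterwards. A hedged sketch (only the option fields that appear above are assumed to exist; fileExists is a hypothetical helper):

// Hypothetical caching driver around getEngine (illustrative values only).
BuildOptions build;
build.engine = "model.trt";             // path used by both loadEngine and saveEngine
build.load   = fileExists("model.trt"); // hypothetical helper: reuse a cached plan if present
build.save   = !build.load;             // otherwise build once and serialize for next time

ModelOptions model;  // describes the source model (.onnx / .prototxt + .caffemodel)
SystemOptions sys;   // sys.DLACore == -1 selects the GPU path

auto engine = getEngine(model, build, sys, std::cerr);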
    
    
    
    

2.1 Load Engine from .trt file

ICudaEngine* loadEngine(const std::string& engine, int DLACore, std::ostream& err)
{
    std::ifstream engineFile(engine, std::ios::binary);
    if (!engineFile)
    {
        err << "Error opening engine file: " << engine << std::endl;
        return nullptr;
    }

    // Determine the file size, then read the whole serialized plan into memory.
    engineFile.seekg(0, engineFile.end);
    long int fsize = engineFile.tellg();
    engineFile.seekg(0, engineFile.beg);

    std::vector<char> engineData(fsize);
    engineFile.read(engineData.data(), fsize);
    if (!engineFile)
    {
        err << "Error loading engine file: " << engine << std::endl;
        return nullptr;
    }

    // Deserialize the plan with a runtime, optionally targeting a DLA core.
    TrtUniquePtr<IRuntime> runtime{createInferRuntime(gLogger.getTRTLogger())};
    if (DLACore != -1)
    {
        runtime->setDLACore(DLACore);
    }

    return runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);
}
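Once loadEngine returns, inference follows the usual TensorRT 7 pattern; a brief sketch (buffer allocation and binding setup are elided):

nvinfer1::ICudaEngine* engine = loadEngine("model.trt", /*DLACore=*/-1, std::cerr);
if (engine != nullptr)
{
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();
    // ... allocate device buffers, copy inputs in, then run:
    // context->executeV2(bindings);
    context->destroy(); // TRT 7 cleanup goes through destroy()
    engine->destroy();
}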
    
    
    
    

2.2 Load Engine from model file (Caffe, ONNX)

ICudaEngine* modelToEngine(
    const ModelOptions& model, const BuildOptions& build, const SystemOptions& sys, std::ostream& err)
{
    TrtUniquePtr<IBuilder> builder{createInferBuilder(gLogger.getTRTLogger())};
    if (builder == nullptr)
    {
        err << "Builder creation failed" << std::endl;
        return nullptr;
    }
    // ONNX models require an explicit-batch network, so the implicit-batch flag
    // (0U) is only used when maxBatch is set and the model is not ONNX.
    const bool isOnnxModel = model.baseModel.format == ModelFormat::kONNX;
    auto batchFlag = (build.maxBatch && !isOnnxModel)
        ? 0U
        : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    TrtUniquePtr<INetworkDefinition> network{builder->createNetworkV2(batchFlag)};
    if (!network)
    {
        err << "Network creation failed" << std::endl;
        return nullptr;
    }
    Parser parser = modelToNetwork(model, *network, err);
    if (!parser)
    {
        err << "Parsing model failed" << std::endl;
        return nullptr;
    }

    return networkToEngine(build, sys, *builder, *network, err);
}
    
    
    
    
Parser modelToNetwork(const ModelOptions& model, nvinfer1::INetworkDefinition& network, std::ostream& err)
{
    Parser parser;
    const std::string& modelName = model.baseModel.model;
    switch (model.baseModel.format)
    {
    case ModelFormat::kCAFFE:
    {
        using namespace nvcaffeparser1;
        parser.caffeParser.reset(createCaffeParser());
        CaffeBufferShutter bufferShutter;
        const auto blobNameToTensor = parser.caffeParser->parse(
            model.prototxt.c_str(), modelName.empty() ? nullptr : modelName.c_str(), network, DataType::kFLOAT);
        if (!blobNameToTensor)
        {
            err << "Failed to parse caffe model or prototxt, tensors blob not found" << std::endl;
            parser.caffeParser.reset();
            break;
        }

        for (const auto& s : model.outputs)
        {
            if (blobNameToTensor->find(s.c_str()) == nullptr)
            {
                err << "Could not find output blob " << s << std::endl;
                parser.caffeParser.reset();
                break;
            }
            network.markOutput(*blobNameToTensor->find(s.c_str()));
        }
        break;
    }
    ...
}
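The elided cases include the ONNX path. From memory of the TensorRT 7 samples, that branch takes roughly the following shape (a hedged reconstruction, not a verbatim copy; check your local sources for the exact code):

case ModelFormat::kONNX:
{
    using namespace nvonnxparser;
    parser.onnxParser.reset(createParser(network, gLogger.getTRTLogger()));
    // parseFromFile reports detailed errors through the logger at the given verbosity.
    if (!parser.onnxParser->parseFromFile(
            model.baseModel.model.c_str(), static_cast<int>(gLogger.getReportableSeverity())))
    {
        err << "Failed to parse onnx file" << std::endl;
        parser.onnxParser.reset();
    }
    break;
}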
    
    
    
    
ICudaEngine* networkToEngine(const BuildOptions& build, const SystemOptions& sys, IBuilder& builder,
    INetworkDefinition& network, std::ostream& err)
{
    TrtUniquePtr<IBuilderConfig> config{builder.createBuilderConfig()};

    IOptimizationProfile* profile{nullptr};
    if (build.maxBatch)
    {
        builder.setMaxBatchSize(build.maxBatch);
    }
    ...
    bool hasDynamicShapes{false};
    for (unsigned int i = 0, n = network.getNbInputs(); i < n; i++)
    {
        // Set formats and data types of inputs
        auto input = network.getInput(i);
        if (!build.inputFormats.empty())
        {
            input->setType(build.inputFormats[i].first);
            input->setAllowedFormats(build.inputFormats[i].second);
        }
        else
        {
            switch (input->getType())
            {
            case DataType::kINT32:
            case DataType::kBOOL:
                // Leave these as is.
                break;
            case DataType::kFLOAT:
            case DataType::kINT8:
            case DataType::kHALF:
                // User did not specify a floating-point format.  Default to kFLOAT.
                input->setType(DataType::kFLOAT);
                break;
            }
            input->setAllowedFormats(1U << static_cast<int>(TensorFormat::kLINEAR));
        }

        ...
    } // end of the per-input loop

    for (unsigned int i = 0, n = network.getNbOutputs(); i < n; i++)
    {
        // Set formats and data types of outputs
        auto output = network.getOutput(i);
        if (!build.outputFormats.empty())
        {
            output->setType(build.outputFormats[i].first);
            output->setAllowedFormats(build.outputFormats[i].second);
        }
        else
        {
            output->setAllowedFormats(1U << static_cast<int>(TensorFormat::kLINEAR));
        }
    }

    // The workspace option is given in MiB; the shift converts it to bytes.
    config->setMaxWorkspaceSize(static_cast<size_t>(build.workspace) << 20);

    if (build.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }

    if (build.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
    }

    auto isInt8 = [](const IOFormat& format) { return format.first == DataType::kINT8; };
    auto int8IO = std::count_if(build.inputFormats.begin(), build.inputFormats.end(), isInt8)
        + std::count_if(build.outputFormats.begin(), build.outputFormats.end(), isInt8);

    if ((build.int8 && build.calibration.empty()) || int8IO)
    {
        // Explicitly set int8 scales if no calibrator is provided and if I/O tensors use int8,
        // because auto calibration does not support this case.
        setTensorScales(network);
    }
    else if (build.int8)
    {
        config->setInt8Calibrator(new RndInt8Calibrator(1, build.calibration, network, err));
    }

    if (build.safe)
    {
        config->setEngineCapability(sys.DLACore != -1 ? EngineCapability::kSAFE_DLA : EngineCapability::kSAFE_GPU);
    }

    if (sys.DLACore != -1)
    {
        ...
        config->setDefaultDeviceType(DeviceType::kDLA);
        config->setDLACore(sys.DLACore);
        config->setFlag(BuilderFlag::kSTRICT_TYPES);

        if (sys.fallback)
        {
            config->setFlag(BuilderFlag::kGPU_FALLBACK);
        }
        if (!build.int8)
        {
            config->setFlag(BuilderFlag::kFP16);
        }
        ...
    }

    return builder.buildEngineWithConfig(network, *config);
}
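The elided parts around `profile` and `hasDynamicShapes` set up an optimization profile when the network has dynamic input shapes. A hedged sketch of that idea, using the TensorRT 7 API (the tensor name "input" and the dimensions are illustrative; the sample actually reads them from trtexec's shape options):

// Sketch: register min/opt/max shapes for a dynamic input named "input"
// (hypothetical name), then attach the profile to the builder config.
profile = builder.createOptimizationProfile();
profile->setDimensions("input", OptProfileSelector::kMIN, Dims4{1, 3, 224, 224});
profile->setDimensions("input", OptProfileSelector::kOPT, Dims4{8, 3, 224, 224});
profile->setDimensions("input", OptProfileSelector::kMAX, Dims4{32, 3, 224, 224});
config->addOptimizationProfile(profile);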
    
    
    
    
