TensorRT学习笔记7 - 保存与读取序列化的结果

目录

  • 保存序列化的结果
  • 读取序列化的结果



       我们使用TensorRT转化Caffe模型时,每次都要使用如下代码将模型序列化,之后再进行反序列化,才可以进行inference。但是build engine的过程是十分耗时的,我们可以将序列化的结果保存至本地,以后直接读取本地的序列化结果即可,无需build engine,这样可以节省很多时间。

// 此函数用于得到pluginFactory和gie_model_stream
void objectDetection::caffeToGIEModel(
        const std::string& deploy_file,
        const std::string& model_file,
        const std::vector& outputs,
        unsigned int max_batch_size,
        nvcaffeparser1::IPluginFactory* pluginFactory,
        IHostMemory **gie_model_stream) {
        
    ...... //函数中的其他代码
    
    ICudaEngine *engine = builder->buildCudaEngine(*network);
    
    ...... //函数中的其他代码
    
    // 得到序列化的模型结果
    (*gie_model_stream) = engine->serialize();
    
    ...... //函数中的其他代码
}

// 调用caffeToGIEModel得到_plugin_factory和_gie_model_stream
caffeToGIEModel(_deploy_file, _model_file, _outputs, _batch_size, &_plugin_factory, &_gie_model_stream);

// 将序列化得到的结果进行反序列化,以执行后续的inference
_engine = _runtime->deserializeCudaEngine(_gie_model_stream->data(), _gie_model_stream->size(), &_plugin_factory);

保存序列化的结果

// 此函数用于得到pluginFactory和gie_model_stream
void objectDetection::caffeToGIEModel(
        const std::string& deploy_file,
        const std::string& model_file,
        const std::vector& outputs,
        unsigned int max_batch_size,
        nvcaffeparser1::IPluginFactory* pluginFactory,
        IHostMemory **gie_model_stream) {
        
    ...... //函数中的其他代码
    
    ICudaEngine *engine = builder->buildCudaEngine(*network);
    
    ...... //函数中的其他代码
    
    // 得到序列化的模型结果
    (*gie_model_stream) = engine->serialize();

    // 设置保存文件的名称为cached_model.bin
    std::string cache_path = "../cached_model.bin";
    std::ofstream serialize_output_stream;

    // 将序列化的模型结果拷贝至serialize_str字符串
    std::string serialize_str;
    serialize_str.resize( (*gie_model_stream)->size() );
    memcpy((void*)serialize_str.data(), (*gie_model_stream)->data(), (*gie_model_stream)->size());

    // 将serialize_str字符串的内容输出至cached_model.bin文件
    serialize_output_stream.open(cache_path);
    serialize_output_stream << serialize_str;
    serialize_output_stream.close();
    
    ...... //函数中的其他代码
}

读取序列化的结果

// 调用caffeToGIEModel得到_plugin_factory,将此函数中和_gie_model_stream相关的代码注释或者删除
caffeToGIEModel(_deploy_file, _model_file, _outputs, _batch_size, &_plugin_factory, &_gie_model_stream);

// 从cached_model.bin文件中读取序列化的结果
std::string cached_path = "../cached_model.bin";
std::ifstream fin(cached_path);

// 将文件中的内容读取至cached_engine字符串
std::string cached_engine = "";
while (fin.peek() != EOF){ // 使用fin.peek()防止文件读取时无限循环
    std::stringstream buffer;
    buffer << fin.rdbuf();
    cached_engine.append(buffer.str());
}
fin.close();

// 将序列化得到的结果进行反序列化,以执行后续的inference
_engine = _runtime->deserializeCudaEngine(cached_engine.data(), cached_engine.size(), &_plugin_factory);

你可能感兴趣的:(TensorRT)