[LibTorch] C++ 调用 PyTorch 导出的模型

参考文章

  • C++部署pytorch模型
  • 利用LibTorch部署PyTorch模型
  • 官方文档

问题

pytorch 的神经网络模型有很多,但 libtorch 就特别少。现在面临的问题是要在 C++ 环境下应用神经网络模型,肯定不能直接使用 pytorch 模型。解决办法有两个:

  • 方法一是用 TorchScript 工具导出模型 poolnet.pt,模型中包含网络结构参数权重,因此可以直接在 C++ 里面生成神经网络。

  • 方法二是用 C++ 复现网络结构,封装为为类对象,再从 poolnet.pt 中导入参数权重

[LibTorch] C++ 调用 PyTorch 导出的模型_第1张图片

对于神经网络模型 PoolNet ,将其应用到 C++ 环境下进行视频处理,下面这是前 10 帧画面处理时间。明显看出,方法一前两次运行时间很长,从第三帧开始,两种方法的处理时间几乎相同。但是,方法一相当简单,导出模型即可,方法二需要复现网络结构,工程量巨大。下面重点介绍方法一

[LibTorch] C++ 调用 PyTorch 导出的模型_第2张图片

TorchScript 工具介绍

必定要看 官方文档。上面介绍了 trace 和 script 的区别。

PyTorch 导出模型

resnet50

编辑 export.py 文件,以 pytorch 提供的 resnet50 为例,分别使用 trace 和 script 导出模型。trace 需要提供一个输入样例,script 则不需要。但是复杂的模型使用 script 一般会失败,但 trace 可以。trace 和 script 导出的模型几乎没有区别,缺点是前两次处理时间都格外久

import torch
from torchvision.models import resnet50

net = resnet50(pretrained=True)
net = net.cuda()
net.eval()
for key, value in net.named_parameters(): 
	print(key)

# trace
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
traced_module = torch.jit.trace(net, x)
traced_module.save("resnet50_trace.pt")

# script
scripted_module = torch.jit.script(net)
scripted_module.save("resnet50_script.pt")

在 python3+pytorch 的虚拟环境下执行

python export.py

net

上面的例子是 pytorch 提供的 resnet50,如果是自己写的模型,可以按照下面的方式来。其中,net.pth 是训练后保存的参数,net.pt 则是期望导出的模型,使用 trace 方法。

import torch 
import torchvision

# 初始化神经网络
net = Net()
net.load_state_dict(torch.load("net.pth"))
net.cuda()
net.eval()

# 导出模型
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
m = torch.jit.trace(net, x)
m.save("net.pt")

其中

m.save("net.pt")

也可写为

torch.jit.save(m, "net.pt")

C++ 中调用模型

C++ 使用 libtorch 时,一般使用 CMake 进行管理(参考 Pytorch 官网教程)。下面是在 C++ 环境中调用模型的方法。

#include 
#include 

torch::Device device(torch::kCUDA);
// image.rows, image.cols 高在前,宽在后
torch::Tensor img_tensor = torch::from_blob(img.data, {1, image.rows, image.cols, 3}, torch::kByte).to(device);
img_tensor = img_tensor.permute({0, 3, 1, 2});
img_tensor = img_tensor.toType(torch::kFloat);
img_tensor = img_tensor.div(255.0);
torch::jit::script::Module net = torch::jit::load("../models/net.pt");
// 打印模型中的参数
for (const auto& pair : net.named_parameters()) {
    std::cout << pair.name << " " << pair.value.requires_grad() << std::endl;
}
net.to(device)
torch::NoGradGuard no_grad;
torch::Tensor output = net.forward({img_tensor}).toTensor();

你可能感兴趣的:(LibTorch,PyTorch)