TorchScript 可以看作Python的一个子集,主要的应用场景是把Python/PyTorch代码转换成等价的C++代码从而提高深度学习模型在线上生产环境部署的运行效率。Python代码会被编译成TorchScript编译器可以理解的一种格式(ScriptModule),C++的生产环境可以载入该格式的文件并用内置的JIT来执行对应的代码。
TorchScript提供了两种方法来把Python代码转换成TorchScript representation,分别为:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
在这里我们会从源代码层面去分析TorchScript是如何实现上面这两种方案的,并且每个方案自身的限制有哪些,来源是什么。
如上面的例子所示,tracing的应用场景是当你已经有了一个nn.Module (PyTorch里面定义神经网络的基本单元之一),你可以随便构造一个输入,然后告诉tracer:
非常简单直接,也跟这个名字非常吻合——在这个模型运行的过程中,有一台跟踪者在跟踪每一步的执行然后记录,因而得名tracer。
具体的实现方式,是通过在Operator的代码(也是C++)里面加上额外的追踪代码[1][2]。因为PyTorch的Module还有Operator为了执行效率本来就是C++的代码bind到了Python环境,所以在运行这个网络的过程之中自然会执行到Operator的C++源码,而其中就顺带执行了追踪代码。在追踪代码里面,每执行一个operator,就会往当前TracingState(定义成一个线程局部变量)里面的graph加入一个node。所有代码执行完毕,每一步的操作就会以一个Computation Graph里的某个节点的形式被保存下来。
(由于Python是单线程的,所以整个Computation Graph代码的执行顺序也是线性的,不用担心多线程带来的混乱。)
但是Tracing有如下限制:
跟实现方法一对照,我们很容易可以理解为什么有这些限制
Scripting,从上面的例子来看,似乎跟Tracing区别不大,但是其实现方法非常不一样。概括而言,scripting是通过把Python的源代码解析成语法树,然后转化成C++可执行代码来实现的。
因为是直接编译源代码,除了应用在nn.Module上面,script也可以直接被用来annotate一般的python class/function,并且可以支持条件语句等tracing不能处理的情况。但这也有缺点:现在的实现只能支持编译Python语法特定子集的代码,因此存在一部分的代码在tracing里可以work但在scripting这边由于编译器的限制不支持。
如下是scripting的实现细节(根据Python源代码的来源不同会有差别):
#1 从最简单的开始:如果代码来源是Python函数(def foo()),那么大致流程如下:
#2 如果需要转换的代码是一个类(class Foo(object)),大致流程跟Python函数的case差不多,不过有一些限制和细微差别:
#3 如果需要转换的代码来自于一个nn.Module (PyTorch里面用来定义神经网络的类)的实例,大致流程会相对复杂一点:
个人观点(也包括跟在PyTorch组工作的Engineer讨论得出的结论)