No way around it: sooner or later you have to pay your dues, and if you can't write a bit of low-level code you won't get far. Long story short, sometimes we need custom operations that would be too slow if written in Python, so we write them in CUDA/C++ and then bind them to Python so they can be called from the Python side.
This post mostly condenses https://pytorch.org/tutorials/advanced/cpp_extension.html and, based on actually working through it, fills in a few details (otherwise there are a couple of spots in the official docs that take some trial and error). It checks out; following this post will get you there.
Example program
import math
import torch
import torch.nn.functional as F

class LLTM(torch.nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        # 3 * state_size for input gate, output gate and candidate cell gate.
        # input_features + state_size because we will multiply with [input, h].
        self.weights = torch.nn.Parameter(
            torch.empty(3 * state_size, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        old_h, old_cell = state
        X = torch.cat([old_h, input], dim=1)
        # A custom C++ extension can turn these operations into a fused version.
        # Compute the input, output and candidate cell gates with one MM.
        gate_weights = F.linear(X, self.weights, self.bias)
        # Split the combined gate weight matrix into its components.
        gates = gate_weights.chunk(3, dim=1)
        input_gate = F.sigmoid(gates[0])
        output_gate = F.sigmoid(gates[1])
        # Here we use an ELU instead of the usual tanh.
        candidate_cell = F.elu(gates[2])
        # Compute the new cell state.
        new_cell = old_cell + candidate_cell * input_gate
        # Compute the new hidden state and output.
        new_h = F.tanh(new_cell) * output_gate
        return new_h, new_cell
import torch

batch_size = 16
input_features = 32
state_size = 128

X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)

rnn = LLTM(input_features, state_size)
new_h, new_C = rnn(X, (h, C))
Simply put, once the C++ code is written, Python can call it through pybind11. However, setting up pybind11 requires pytest, and pytest apparently only runs on Python 3.5 and above. So first set up a Python-3-based Anaconda, install PyTorch in it (however you manage to get it installed), and then carry on with the steps below.
git clone https://github.com/pybind/pybind11.git
pip install pytest
Note that pip -V here should preferably point to the pip inside anaconda3, so that the pytest you download is the Python 3 version.
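Optionally, a quick sanity check from inside the Anaconda Python to confirm that the interpreter and pytest are the ones you expect (a minimal sketch, not part of the official steps):

import sys
import pytest

# The interpreter path should point into the anaconda3 installation,
# and the Python version must satisfy the pytest requirement mentioned above.
print(sys.executable)
print(sys.version_info)
print(pytest.__version__)
assert sys.version_info >= (3, 5)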
cd pybind11
mkdir build
cd build
cmake ..
make check -j 4
The compiled shared libraries are the .so files under the tests directory of the build.
To write a PyTorch extension, we need to download the PyTorch source code.
git clone --recursive https://github.com/pytorch/pytorch
Then, under the PyTorch root directory, create a folder, say lltm-extension, and inside it create a setup.py containing:
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

setup(
    name='lltm',
    ext_modules=[CppExtension('lltm', ['lltm.cpp'])],
    cmdclass={'build_ext': BuildExtension})
This is what compiles the C++ code.
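As an aside, torch.utils.cpp_extension also provides a JIT route that skips setup.py entirely and compiles the sources on the fly; a minimal sketch (the module name here is just illustrative):

# JIT alternative to setup.py: compiles lltm.cpp the first time this runs
# and caches the resulting extension module.
from torch.utils.cpp_extension import load

lltm = load(name='lltm', sources=['lltm.cpp'], verbose=True)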
Then create a new lltm.cpp in that directory and paste in the code below.
Note the headers used here. It is worth noting that extension.h mainly pulls in three things: the ATen library, pybind11, and headers that manage the interaction between the two.
#include <torch/extension.h>
#include <vector>
std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  auto X = at::cat({old_h, input}, /*dim=*/1);
  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);

  auto input_gate = at::sigmoid(gates[0]);
  auto output_gate = at::sigmoid(gates[1]);
  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);

  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = at::tanh(new_cell) * output_gate;

  return {new_h,
          new_cell,
          input_gate,
          output_gate,
          candidate_cell,
          X,
          gate_weights};
}
// tanh'(z) = 1 - tanh^2(z)
at::Tensor d_tanh(at::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// sigmoid'(z) = (1 - sigmoid(z)) * sigmoid(z)
at::Tensor d_sigmoid(at::Tensor z) {
  auto s = at::sigmoid(z);
  return (1 - s) * s;
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}
std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  auto d_output_gate = at::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gates.mm(weights);
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
// We need to add these lines at the end
// to bind the functions to the Python side.
PYBIND11_MODULE(lltm, m) {
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}
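The helpers d_tanh, d_sigmoid and d_elu above are just the closed-form derivatives of the corresponding activations, which we need because the backward pass is written by hand. If you want to convince yourself the formulas are right, a quick PyTorch-side numerical check of the sigmoid one could look like this (a sketch, not part of the extension):

import torch

# Compare autograd's gradient of sigmoid with the closed form (1 - s) * s
# used by d_sigmoid in lltm.cpp.
z = torch.randn(5, dtype=torch.float64, requires_grad=True)
torch.sigmoid(z).sum().backward()

s = torch.sigmoid(z.detach())
print(torch.allclose(z.grad, (1 - s) * s))  # expected: True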
At this point our directory looks like this:

pytorch/
  lltm-extension/
    lltm.cpp
    setup.py
Then run python setup.py install. Note that the setup.py here is the one under lltm-extension, not the one in the PyTorch root directory.
The output looks roughly like the log below. As you can see, only our lltm extension gets compiled, nothing else, and the built Python package lands right inside the Anaconda environment as lltm-0.0.0-py3.6-linux-x86_64.egg.
running install
running bdist_egg
running egg_info
writing lltm.egg-info/PKG-INFO
writing dependency_links to lltm.egg-info/dependency_links.txt
writing top-level names to lltm.egg-info/top_level.txt
reading manifest file 'lltm.egg-info/SOURCES.txt'
writing manifest file 'lltm.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'lltm' extension
gcc -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I~/local/miniconda/lib/python3.6/site-packages/torch/lib/include -I~/local/miniconda/lib/python3.6/site-packages/torch/lib/include/TH -I~/local/miniconda/lib/python3.6/site-packages/torch/lib/include/THC -I~/local/miniconda/include/python3.6m -c lltm.cpp -o build/temp.linux-x86_64-3.6/lltm.o -DTORCH_EXTENSION_NAME=lltm -std=c++11
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
g++ -pthread -shared -B ~/local/miniconda/compiler_compat -L~/local/miniconda/lib -Wl,-rpath=~/local/miniconda/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.6/lltm.o -o build/lib.linux-x86_64-3.6/lltm.cpython-36m-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.6/lltm_cuda.cpython-36m-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.6/lltm.cpython-36m-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for lltm.cpython-36m-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/lltm.py to lltm.cpython-36.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.lltm.cpython-36: module references __file__
creating 'dist/lltm-0.0.0-py3.6-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing lltm-0.0.0-py3.6-linux-x86_64.egg
removing '~/local/miniconda/lib/python3.6/site-packages/lltm-0.0.0-py3.6-linux-x86_64.egg' (and everything under it)
creating ~/local/miniconda/lib/python3.6/site-packages/lltm-0.0.0-py3.6-linux-x86_64.egg
Extracting lltm-0.0.0-py3.6-linux-x86_64.egg to ~/local/miniconda/lib/python3.6/site-packages
lltm 0.0.0 is already the active version in easy-install.pth
Installed ~/local/miniconda/lib/python3.6/site-packages/lltm-0.0.0-py3.6-linux-x86_64.egg
Processing dependencies for lltm==0.0.0
Finished processing dependencies for lltm==0.0.0
When we run conda list, we find this entry:
lltm 0.0.0 pypi_0 pypi
Excellent, it has become a package, so we can import it directly.
>>> import torch
>>> import lltm
>>> lltm.forward
<function lltm.PyCapsule.forward>
>>> help(lltm.forward)
forward(...) method of builtins.PyCapsule instance
    forward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor) -> List[at::Tensor]

    LLTM forward
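Before wrapping this into an autograd Function, we can poke the raw binding directly with a few random tensors and compare it against the pure-Python math from the beginning of the post; a small sketch (sizes here are arbitrary):

import torch
import torch.nn.functional as F
import lltm

batch, features, state = 4, 8, 16
X = torch.randn(batch, features)
W = torch.randn(3 * state, features + state)
b = torch.randn(3 * state)
h = torch.randn(batch, state)
C = torch.randn(batch, state)

# C++/ATen forward; the first two outputs are new_h and new_cell.
new_h, new_cell = lltm.forward(X, W, b, h, C)[:2]

# Pure-Python reference, same math as LLTM.forward above.
gates = F.linear(torch.cat([h, X], dim=1), W, b).chunk(3, dim=1)
ref_cell = C + F.elu(gates[2]) * torch.sigmoid(gates[0])
ref_h = torch.tanh(ref_cell) * torch.sigmoid(gates[1])

print(torch.allclose(new_h, ref_h), torch.allclose(new_cell, ref_cell))  # True True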
In the new folder, create a file LLTM_Module.py:
import math
import torch

# Our module!
import lltm


class LLTMFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)
        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = lltm.backward(
            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables)
        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(torch.nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = torch.nn.Parameter(
            torch.empty(3 * state_size, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
Then create another file, run_lltm.py:
import torch
from LLTM_Module import LLTM
import time

assert torch.cuda.is_available()
cuda_device = torch.device("cuda")

batch_size = 16
input_features = 32
state_size = 128

X = torch.randn(batch_size, input_features, device=cuda_device)
h = torch.randn(batch_size, state_size, device=cuda_device)
C = torch.randn(batch_size, state_size, device=cuda_device)

rnn = LLTM(input_features, state_size).to(cuda_device)

forward = 0
backward = 0
for _ in range(100000):
    start = time.time()
    new_h, new_C = rnn(X, (h, C))
    torch.cuda.synchronize()
    forward += time.time() - start

    start = time.time()
    (new_h.sum() + new_C.sum()).backward()
    torch.cuda.synchronize()
    backward += time.time() - start

print('Forward: {:.3f} us | Backward {:.3f} us'.format(forward * 1e6 / 1e5, backward * 1e6 / 1e5))
You can see that lltm.cpp is basically the PyTorch Python API translated directly into the ATen API. Running the benchmark, there is indeed an improvement. Also worth noting: a C++ extension based on ATen written this way works for both CPU and GPU data. If you change the data in run_lltm.py to CPU tensors, it still runs, as sketched below.
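A minimal CPU variant of the benchmark script, just to confirm the extension also runs on CPU tensors (a sketch; sizes copied from run_lltm.py):

import torch
from LLTM_Module import LLTM

batch_size, input_features, state_size = 16, 32, 128

# Same setup as run_lltm.py, but everything stays on the CPU.
X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)

rnn = LLTM(input_features, state_size)
new_h, new_C = rnn(X, (h, C))
(new_h.sum() + new_C.sum()).backward()  # exercises lltm.backward through autograd
print(new_h.shape, new_C.shape)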
The results for Python + GPU (first line) versus C++/ATen + GPU (second line) are as follows:
Forward: 187.719 us | Backward 410.815 us
Forward: 149.802 us | Backward 393.458 us
What, that's all the improvement we get?? Lame! **Actually, what we just used is the ATen version, which is simple to write; we can go further with custom CUDA kernels** to speed things up even more.
See the follow-up post.