Normally, PyTorch computes gradients for us automatically through its autograd machinery. However, if a loss function involves operations that autograd cannot differentiate on its own, the automatic mechanism is not enough and we have to supply the gradients ourselves.
Adding such operations to autograd requires implementing a new Function subclass for each operation. Recall that Functions are what autograd uses to compute results and gradients, and they encode the operation history.
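As a small illustrative sketch of that history, every differentiable result carries a grad_fn pointing at the Function node that produced it, and backward() walks those nodes (the shapes here are arbitrary):

import torch

x = torch.randn(3, requires_grad=True)
y = x.exp().sum()
print(y.grad_fn)  # a SumBackward0 Function node recorded by autograd
y.backward()      # autograd replays the recorded Function graph
print(x.grad)     # equals x.exp(), the analytic gradient of sum(exp(x))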
Following the official guide "Extending torch.autograd", a new operation is implemented as follows:
(1) __init__ (optional) - if your operation takes non-tensor arguments, pass them to the operation through __init__. For example, an AddConstant Function adds a constant, and a Transpose Function needs to be told which two dimensions to swap. If your operation needs no extra parameters, you can omit __init__. (With the current static-method API such constants are usually passed straight to forward() and stashed on ctx, as in the sketch after this list.)
(2) forward() - computes the forward pass of the op.
(3) backward() - computes the gradients.
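For instance, here is a minimal sketch, modeled on the constant-multiplication example in the PyTorch documentation, of a Function that takes a non-tensor constant (the class name MulConstant is only illustrative):

import torch
from torch.autograd import Function

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        ctx.constant = constant  # non-tensor arguments can be stored directly on ctx
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # backward() returns one value per forward() input;
        # the constant is not a tensor, so its "gradient" is None
        return grad_output * ctx.constant, None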
Demo:
import torch
from torch.autograd import Function

class LinearFunction(Function):
    # Subclass of torch.autograd.Function.
    # forward() and backward() must be staticmethods.
    @staticmethod
    # The first argument is ctx, the second is input, the rest are optional.
    # ctx plays a role similar to self: attributes saved on ctx in forward()
    # can be read back in backward().
    # In a user-defined Function's forward(), all Variable arguments are
    # unpacked into plain tensors by the autograd engine before the call,
    # so input here is a tensor as well.
    def forward(ctx, input, weight, bias=None):
        print(type(input))
        ctx.save_for_backward(input, weight, bias)  # save the tensors on ctx for use in backward()
        output = input.mm(weight.t())  # torch.t() transposes a 2D tensor
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)  # unsqueeze(0) inserts a new dim 0
            # expand_as(tensor) is equivalent to expand(tensor.size()):
            # broadcast the original tensor to the new size
        return output
    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the gradient passed down from the next stage of backpropagation
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # gradients with respect to the input, the weight, and the bias
        # only compute a gradient if the corresponding input requires it
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)  # chain rule: dL/dinput = grad_output @ W
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)  # chain rule: dL/dW = grad_output^T @ input
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)  # sum over the batch dimension
        return grad_input, grad_weight, grad_bias
# It is recommended to wrap the new operation in a plain function
def linear(input, weight, bias=None):
    # Function.apply records the operation in the autograd graph,
    # calls forward() to compute the result, and returns it.
    return LinearFunction.apply(input, weight, bias)
# Or simply give the apply method an alias
linear = LinearFunction.apply
# Check that the backward() implementation is correct
from torch.autograd import gradcheck
# gradcheck takes a tuple of tensors as input, checks whether the gradients
# evaluated with these tensors are close enough to numerical approximations,
# and returns True if they all satisfy this condition.
input = (torch.randn(20, 20, dtype=torch.double, requires_grad=True),
         torch.randn(30, 20, dtype=torch.double, requires_grad=True))
test = gradcheck(LinearFunction.apply, input, eps=1e-6, atol=1e-4)
print(test)  # prints True if everything is fine
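To use the new operation inside a network, it can be wrapped in an nn.Module with learnable parameters. The sketch below follows the pattern of the "Extending torch.autograd" guide; the feature sizes and the uniform initialization range are illustrative assumptions:

import torch
import torch.nn as nn

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # register None so the module still exposes a bias attribute
            self.register_parameter('bias', None)
        # illustrative initialization; any scheme could be used here
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # runs our custom forward()/backward() through LinearFunction
        return LinearFunction.apply(input, self.weight, self.bias)

layer = Linear(20, 30)
out = layer(torch.randn(4, 20))  # out has shape (4, 30)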