PyTorch 在 torch.nn 模块中提供了均方误差函数:
import torch.nn as nn
mse_loss = nn.MSELoss()
只需实现 forward() 方法,无需手动编写反向传播(自动求导引擎处理)。自定义损失函数类实例化后直接调用即可计算损失值。
继承 nn.Module 自定义均方误差损失函数的实现代码如下:
import torch.nn as nn
class MSELossV1(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, target):
squared_diff = (input - target) ** 2
n = squared_diff.numel()
return squared_diff.sum() / n
在 PyTorch 中,torch.autograd.Function 是一个用于定义自定义自动求导操作的类。它允许用户实现自定义的前向传播forward 和反向传播 backward 逻辑。这对于实现非标准操作、自定义激活函数、或在某些特殊场景中替代现有 PyTorch 操作非常有用。
torch.autograd.Function 实现自定义求导,需要实现 forward
和 backward
方法,这意味着需要自己手算反向传播求梯度公式。
ctx 是上下文对象,用于在 forward 和 backward 之间传递数据。常用方法是:
ctx.save_for_backward(*tensors)
:保存张量供反向传播使用ctx.saved_tensors
:获取保存的张量forward 方法返回计算结果,而 backward 返回对每个输入的梯度。
Function.apply(input)
是调用自定义函数的标准方式。继承 autograd.Function 自定义均方误差损失函数的实现代码如下:
import torch
from torch.autograd import Function
class MSELossV2(Function):
@staticmethod
def forward(ctx, input, target):
squared_diff = (input - target) ** 2
n = squared_diff.numel()
ctx.save_for_backward(input, target)
return squared_diff.sum() / n
@staticmethod
def backward(ctx, grad_output):
input, target = ctx.saved_tensors
n = input.numel()
grad_input = 2 / n * (input - target) * grad_output
return grad_input, None
在 PyTorch 的 torch.autograd.Function 中,backward 方法的返回值数量和顺序必须与 forward 方法的输入参数一一对应。例如,forward 传入的参数为 input 和 target,则 backward 也要返回两个梯度(例如 grad_input, None)。
每个输入参数都需要对应一个梯度输出:
backward 中的 grad_output 是一个张量,其形状与当前操作的输出张量一致。它表示在反向传播时,每个输出元素的梯度乘以一个
权重(即 grad_output 的值),从而影响输入梯度的计算。
grad_output 的使用总结如下:
场景 | grad_output 的作用 |
示例 |
---|---|---|
标量输出 | 默认为 1,无需显式指定 | loss.backward() |
向量输出 | 必须指定,形状与输出一致 | y.backward(torch.ones_like(y)) |
多输出 | 每个输出对应一个 grad_output |
grad_output=[v1, v2] |
自定义反向传播 | 传递上层梯度,计算输入梯度 | backward(ctx, grad_output) |
代码示例:
import torch
x = torch.tensor([2.0], requires_grad=True)
у = x**2
у.backward() # 等价于 y.backward(torch.tensor(1.0))
print(x.grad) # 输出 4.0 (dy/dx = 2x = 4)
x2 = torch.tensor([1.0, 2.0], requires_grad=True)
y = x2 * 2
grad_output = torch.tensor([1.0, 0.5]) # 权重分别为 1 和 0.5
y.backward(grad_output) # x2_grad = tensor([2., 1.]) (grad_output · dy/dx = [1.0, 0.5] · [2., 2.] = [2., 1.])
实验代码如下:
import torch
import torch.nn as nn
from torch.autograd import Function
class MSELossV1(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, target):
squared_diff = (input - target) ** 2
n = squared_diff.numel()
return squared_diff.sum() / n
class MSELossV2(Function):
@staticmethod
def forward(ctx, input, target):
squared_diff = (input - target) ** 2
n = squared_diff.numel()
ctx.save_for_backward(input, target)
return squared_diff.sum() / n
@staticmethod
def backward(ctx, grad_output):
input, target = ctx.saved_tensors
n = input.numel()
grad_input = 2 / n * (input - target) * grad_output
return grad_input, None
if __name__ == "__main__":
mse_loss = nn.MSELoss()
mse_loss_v2 = MSELossV1()
x = torch.tensor(
[[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]], requires_grad=True
)
x2 = x.detach().clone().requires_grad_(True)
x3 = x.detach().clone().requires_grad_(True)
y = torch.tensor(
[[0.5, 2.5, 2.0],
[3.5, 5.5, 5.0],
[6.5, 8.5, 8.0]]
)
loss = mse_loss(x, y)
loss2 = mse_loss_v2(x2, y)
loss3 = MSELossV2.apply(x3, y)
print(f"loss: {loss}, loss2: {loss2}, loss3: {loss3}")
loss.backward()
loss2.backward()
loss3.backward()
print(f"x.grad: \n{x.grad}\n x2.grad: \n{x2.grad}\n x3.grad: \n{x3.grad}")