作者:Accelemate
发布时间:2025年6月26日
本文摘要:本文将从零开始,系统性地讲解 PyTorch 中的计算图、反向传播、with torch.no_grad()
、.detach()
等核心机制,结合实践场景如可视化中间层特征图、GAN 模型中对生成器的冻结操作等内容,帮助你在实际开发中灵活、正确地使用自动微分特性。
PyTorch 的自动微分机制(Autograd)是其最核心的特性之一。
简而言之:
你只需要写出前向传播的过程(即如何从输入计算出输出),PyTorch 就会自动帮你记录每一步操作,构建出一个“计算图”,并且在调用
.backward()
时,沿着这个图,自动地反向传播误差,计算所有需要梯度的参数的梯度。
requires_grad
属性x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
requires_grad=True
时,计算图才会追踪这个张量。.backward()
到参数更新在模型训练过程中,每次前向传播得到 loss(损失值)后,调用:
loss.backward() # 反向传播,计算所有参数的梯度
optimizer.step() # 更新模型参数
optimizer.zero_grad() # 梯度清零,为下次迭代做准备
.backward()
是我们所说的“反向传播”核心函数,它的入口是 Loss。
你可能会问:我们不是要优化权重吗?为啥对 Loss 做反向传播?
✅ 答案是:
.backward()
就等于从 Loss 开始,自动向后走回网络的各层参数,使用链式法则一步步计算每个参数对 Loss 的梯度。这就是自动微分的强大之处。
requires_grad=True
的张量进行运算时,PyTorch 就会动态地记录这些操作,并构成一张有向无环图(DAG),这个图就叫做计算图。x = torch.tensor([2.0], requires_grad=True)
y = x ** 2 + 3 * x # 此时构建计算图
x → x**2 → +3x → y
的整个过程。✅ 不是创建张量时,而是在执行运算时。
x = torch.tensor([2.0], requires_grad=True) # 无计算图,仅是张量
z = x ** 2 + 1 # 此刻构建计算图
.grad_fn
与叶子节点.grad_fn
属性,标志其由哪种操作生成。tensor(...)
或 .detach()
创建的张量没有 .grad_fn
,这些是叶子节点(计算图的起点)。.backward()
y = (x ** 2).sum()
y.backward() # 自动计算 x 的梯度
print(x.grad) # tensor([2.0, 4.0, 6.0])
.backward()
的限制y.backward(gradient=torch.tensor([1.0, 1.0, 1.0]))
.backward()
的陷阱loss.backward(retain_graph=True)
.backward()
会释放计算图;多次反向需要设定 retain_graph=True
。with torch.no_grad()
:推理与节省资源model.eval()
with torch.no_grad():
output = model(input)
no_grad()
?即使你不写 .backward()
,PyTorch 仍会构建计算图,记录中间操作,浪费内存。
✅ 使用 no_grad()
可以:
.grad
。所以即便不训练,也应养成用 no_grad()
的好习惯。
.detach()
:从计算图中摘除变量x = torch.tensor([2.0], requires_grad=True)
y = x * 3
y_detached = y.detach()
y_detached
与 y
数据相同;y_detached.grad_fn = None
。❌ 不是。是“浅拷贝 + 剪断计算图”。
.detach()
返回的张量与原张量共享内存,但不再追踪其历史。features = model.conv1(input_image) # 中间层输出
features = features.detach() # 切断反向传播路径
为什么要可视化?看模型是否提取了我们期望的特征,比如边缘、纹理等:
import matplotlib.pyplot as plt
for i in range(6):
plt.imshow(features[0, i].cpu(), cmap='gray')
plt.title(f'Feature Map {i}')
plt.show()
fake_imgs = generator(z).detach()
d_real = discriminator(real_imgs)
d_fake = discriminator(fake_imgs)
x = x.detach() + torch.randn_like(x) * noise_level
.detach()
冻结操作生成对抗网络(GAN) = 生成器 + 判别器:
# 1. 训练 D
fake_imgs = generator(z).detach()
loss_D = loss_fn(discriminator(real_imgs), 1) + \
loss_fn(discriminator(fake_imgs), 0)
loss_D.backward()
# 2. 训练 G
fake_imgs = generator(z)
loss_G = loss_fn(discriminator(fake_imgs), 1)
loss_G.backward()
.detach()
手动断开反向路径。✅ 拆成两步训练,才能实现 G 和 D 的“对抗式交替训练”。
.backward()
的目标是 Loss?因为 Loss 是模型输出与目标之间差距的标量度量。
从 Loss 出发反推,才能知道“哪里错了”,以修正前面的参数。
不是。计算图在你使用 requires_grad=True
的张量进行运算时才开始构建。
.backward()
,为何还用 no_grad()
?虽然你不写 .backward()
,但不加 no_grad()
,PyTorch 仍然默默构建计算图,浪费内存。
no_grad()
是告诉系统:“我只要 forward,不要 graph 和 grad。”
.grad_fn
是不是就是 .grad
?不是。
.grad_fn
是这个张量是由哪个操作生成的(即图中是哪一个函数节点);.grad
是在反向传播时被填充的梯度值。概念/方法 | 含义 | 用法 | 应用场景 |
---|---|---|---|
requires_grad=True |
开启梯度追踪 | 定义张量 | 模型参数 |
.backward() |
启动反向传播 | loss.backward() |
从 loss 回溯计算各层梯度 |
torch.no_grad() |
关闭计算图追踪 | with torch.no_grad(): |
推理、冻结部分模块 |
.detach() |
手动剪断计算图 | y = x.detach() |
可视化中间层、GAN 冻结 G |
.grad_fn |
当前张量由哪个操作生成 | y.grad_fn |
理解计算图结构 |
.grad |
张量的梯度值 | x.grad |
反向传播后查看参数梯度 |
理解 PyTorch 的自动微分机制不仅仅是学会调用 .backward()
,更重要的是理解它背后是如何动态构建计算图、如何节省资源、如何控制信息流动的。尤其是在复杂任务如 GAN、NLP 微调中,合理使用 .detach()
和 no_grad()
是你进阶为深度学习开发者的标志之一。
希望这篇文章帮你从多个维度、多个场景完全吃透这些关键机制。如果觉得有帮助,欢迎点赞、收藏。