PyTorch torch.no_grad() 指南(笔记)

PyTorch torch.no_grad() 权威

在PyTorch深度学习框架中,高效的显存管理对于训练复杂模型和执行大规模推理任务至关重要。显存不足(Out Of Memory, OOM)错误是开发者经常面临的挑战之一 。torch.no_grad() 作为PyTorch提供的一个核心工具,能够在推理(inference)和验证(validation)阶段显著优化显存使用并提升计算速度。本报告旨在全面、深入地剖析 torch.no_grad() 的工作原理、使用方法、对显存和性能的具体影响,并通过实例代码展示其应用,同时指出常见的使用陷阱与误区,帮助读者更精确、自信地运用这一工具。

一、 PyTorch中的显存管理与torch.no_grad()的重要性

PyTorch凭借其灵活性和动态图机制,在学术界和工业界得到了广泛应用。然而,随着模型规模的增大和数据量的激增,GPU显存往往成为瓶颈。训练过程中,为了进行反向传播和参数更新,PyTorch需要构建计算图并存储中间变量(激活值),这会消耗大量显存。在模型推理或评估阶段,这些梯度计算和历史追踪通常是不必要的。此时,若不加以控制,依然会产生不必要的显存开销,甚至导致OOM错误,中断程序运行 。

torch.no_grad() 正是PyTorch为解决这一问题提供的关键机制。它能够临时关闭自动求导引擎的部分功能,从而在不需要计算梯度的代码块中,避免存储计算历史和中间激活值,达到节省显存、加速计算的目的。理解并正确使用 torch.no_grad() 对于编写高效、稳健的PyTorch推理代码至关重要。本报告将系统性地阐述torch.no_grad()的各个方面,旨在为PyTorch用户提供一份详尽的实践指南。

II. 理解PyTorch的自动求导机制 (Autograd)

为了充分理解 torch.no_grad() 的作用,首先需要对PyTorch的自动求导机制(Autograd)有一个清晰的认识。Autograd是PyTorch实现神经网络训练的核心。

A. 计算图 (Computation Graph) 基础

PyTorch采用动态计算图(Dynamic Computation Graph)的方式进行运算。当执行张量(Tensor)操作时,如果这些操作需要被追踪(通常是为了后续的梯度计算),PyTorch会即时构建一个有向无环图 (DAG) 。在这个图中:

  • 节点 (Nodes) 代表张量。
  • 边 (Edges) 代表产生输出张量的操作(函数)。

这个计算图记录了数据是如何从输入张量一步步计算得到输出张量的。例如,如果 y=x1​+x2​,z=y∗x3​,那么 x1​,x2​,x3​ 是叶子节点, y,z 是中间节点或根节点。这个图的关键作用在于反向传播:当计算损失函数关于模型参数的梯度时,Autograd会从表示损失的节点开始,沿着图反向追溯,利用链式法则计算并累积梯度至各个叶子节点(通常是模型的参数)。

PyTorch计算图的动态性意味着图是在运行时定义的,并且在每次迭代(例如,每个训练批次)后都可以从头开始重新构建 。这种特性为使用Python的控制流语句(如循环和条件判断)提供了极大的灵活性,因为图的结构可以根据运行时的条件而改变。然而,这也意味着如果某些操作持续被追踪而没有及时释放,计算图可能会变得非常庞大,从而消耗大量内存。

B. requires_grad 属性

张量的 requires_grad 属性是控制Autograd是否追踪其操作的关键。

  • 默认情况下,新创建的张量的 requires_gradFalse
  • 如果一个张量是模型的参数(即 torch.nn.Parameter 的实例),那么它的 requires_grad 默认为 True
  • 可以通过 tensor.requires_grad_(True) 或在创建时指定 requires_grad=True 来显式设置。

当一个操作至少有一个输入张量的 requires_gradTrue 时,其输出张量通常也会具有 requires_grad=True,并且该操作会被记录在计算图中,以便后续的梯度计算 。

C. Autograd 如何追踪历史

如果一个张量的 requires_gradTrue,并且它是由某个操作产生的,那么它会拥有一个 .grad_fn 属性。这个属性指向一个特殊的对象(如 AddBackward0, MulBackward0 等),该对象记录了创建该张量的函数以及其输入 。这些 .grad_fn 对象相互连接,构成了从输出回溯到输入的路径,即计算图的一部分。

在执行 .backward() 操作时(通常在损失张量上调用),Autograd引擎会:

  1. 从调用 .backward() 的张量开始,沿着 .grad_fn 链反向遍历计算图。
  2. 计算每个操作相对于其输入的梯度。
  3. 将计算得到的梯度累积到相应张量的 .grad 属性中(仅对 requires_grad=True 的叶子节点有效)。

为了完成梯度计算,Autograd需要在前向传播过程中存储一些中间结果,即激活值(activations)。例如,对于乘法操作 z=x×y,计算 ∂x∂L​ 需要 y 和 ∂z∂L​。这些中间激活值的存储是导致训练过程中显存消耗较大的主要原因之一 。正是因为Autograd需要存储这些用于梯度计算的中间激活值,我们才需要像torch.no_grad()这样的机制来在不需要梯度时避免这种存储,从而节省显存。

III. 深入剖析 torch.no_grad()

torch.no_grad() 是PyTorch中用于控制梯度计算上下文的一个重要工具。它可以作为上下文管理器(context manager)或装饰器(decorator)使用 。

A. 定义与核心功能

torch.no_grad() 的核心功能是在其作用域内临时禁用梯度计算和历史追踪 。这意味着,在 with torch.no_grad(): 代码块内或被 @torch.no_grad() 装饰的函数内执行的所有PyTorch张量运算,都不会被Autograd引擎记录到计算图中。

B. torch.no_grad() 如何影响计算图和 requires_grad

当进入 torch.no_grad() 上下文后:

  1. 不构建计算图:在该上下文中执行的任何操作都不会创建新的计算图节点或边 。即使输入张量本身具有 requires_grad=True 的属性,这些操作也不会被记录。
  2. 输出张量的 requires_grad 属性:在 torch.no_grad() 块内,任何计算的结果张量,其 requires_grad 属性都会被设置为 False,即使所有输入张量的 requires_grad 都为 True 7。这确保了这些输出张量不会参与后续的梯度计算。

一个重要的例外是工厂函数(factory functions)。如果一个函数用于创建新的张量,并且接受一个 requires_grad 关键字参数(例如 torch.nn.Parameter()torch.ones(..., requires_grad=True)),那么即使在 no_grad 上下文中,这些函数创建的张量如果被显式指定 requires_grad=True,其 requires_grad 属性仍然会是 True 。然而,对此类张量进行的后续操作(仍在 no_grad 上下文内)所产生的输出,其 requires_grad 仍将为 False。这种设计确保了参数初始化等场景的灵活性,同时不影响 no_grad 上下文对后续计算的梯度追踪控制。

C. 显存节省机制

torch.no_grad() 节省显存的主要机制在于它阻止了Autograd引擎存储用于反向传播的中间激活值 。如前所述,在正常的训练模式下(即梯度计算开启时),前向传播过程中产生的许多中间张量需要被保留在内存中,因为它们是计算梯度所必需的。当模型非常深或批次较大时,这些中间激活值会占据大量的GPU显存。

torch.no_grad() 上下文中,由于PyTorch知道不需要进行梯度计算,因此它不会保留这些中间激活值。一旦某个中间张量不再被后续的(无梯度)计算所直接需要,其占用的内存就可以被更快地回收或重用。这在模型推理、验证或任何不需要梯度更新的计算场景中,能够显著降低峰值显存占用 。

D. 性能提升机制

除了节省显存,torch.no_grad() 还能带来性能上的提升,主要体现在以下几个方面:

  1. 减少Autograd开销:Autograd引擎在追踪操作、构建计算图以及管理相关元数据时会引入一定的计算开销。禁用梯度计算可以避免这部分开销 。
  2. 更快的计算:由于不需要为反向传播准备数据(如存储中间激活值),某些操作本身可能会执行得更快。尤其是在计算图构建本身成为性能瓶颈时,torch.no_grad() 的效果更为明显。

总而言之,torch.no_grad() 通过阻止Autograd追踪新的操作历史,使得新产生的张量不附带梯度信息,从而避免了为这些新操作存储中间激活值。这不仅直接减少了内存消耗,也因减少了Autograd的簿记工作而可能加速计算。它影响的是在其作用域内发生的操作,而不是永久改变输入张量本身的 requires_grad 状态(输入张量若原本 requires_grad=True,在离开 no_grad 作用域后,参与新的、有梯度追踪的操作时,依然会按其原有属性行事)。

IV. torch.no_grad() 实战演练:示例代码

以下通过几个具体示例来演示 torch.no_grad() 的实际用法。

A. 基本用法:上下文管理器

上下文管理器是 torch.no_grad() 最常见的用法。它清晰地界定了禁用梯度计算的代码范围。

import torch

# 创建一个需要梯度的张量 x
x = torch.tensor([1., 2., 3.], requires_grad=True)
print(f"Original x.requires_grad: {x.requires_grad}")

# 在 torch.no_grad() 上下文管理器内部进行操作
with torch.no_grad():
    y = x * 2  # 对 x 进行操作得到 y
    # 此时,y 是在 no_grad 上下文中创建的,因此 y.requires_grad 会是 False
    print(f"Inside no_grad: y.requires_grad: {y.requires_grad}")

    # 任何在此上下文中进行的操作,其输出都不会追踪梯度
    z = torch.matmul(y, x.T) # 即使 x 仍然是 requires_grad=True, 但 matmul 的输出 z 不会追踪梯度
    print(f"Inside no_grad: z.requires_grad: {z.requires_grad}")

# 离开 no_grad 上下文后
# y 和 z 是在 no_grad 块内创建的,它们的 requires_grad 已经是 False
print(f"Outside no_grad: y.requires_grad: {y.requires_grad}")
print(f"Outside no_grad: z.requires_grad: {z.requires_grad}")

# 如果我们用原始的 x (requires_grad=True) 进行新的操作,梯度会被追踪
# w = x + 10
# print(f"Outside no_grad, new op on x: w.requires_grad: {w.requires_grad}") # 这会是 True

在这个例子中,尽管输入张量 x 设置了 requires_grad=True,但在 with torch.no_grad(): 块内执行的乘法操作 x * 2 所产生的张量 y,其 requires_grad 属性被设置为 False 。同样,后续使用 yx 计算得到的 z 也具有 requires_grad=False

B. 典型的模型推理/验证循环

在模型评估(验证或测试)阶段,通常不需要计算梯度。因此,将整个评估循环置于 torch.no_grad() 的上下文中是标准做法。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 示例模型定义
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)
        self.dropout = nn.Dropout(0.5) # 包含 Dropout 层以演示 model.eval() 的作用

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x) # Dropout 在训练和评估时行为不同
        x = self.linear2(x)
        return x

# 假设 model 和 valid_dataloader 已经定义
# model = SimpleModel()
# # 假设有一些训练好的权重
# # model.load_state_dict(torch.load('model_weights.pth'))
#
# # 创建虚拟数据加载器
# dummy_data = torch.randn(100, 10)
# dummy_labels = torch.randn(100, 1)
# valid_dataset = TensorDataset(dummy_data, dummy_labels)
# valid_dataloader = DataLoader(valid_dataset, batch_size=16)
#
# criterion = nn.MSELoss()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# --------- 实际演示代码 ---------
model = SimpleModel() # 实例化模型
# 创建一个虚拟的验证数据加载器
X_valid = torch.randn(64, 10)
y_valid = torch.randn(64, 1)
valid_dataset = TensorDataset(X_valid, y_valid)
valid_dataloader = DataLoader(valid_dataset, batch_size=32)
criterion = nn.MSELoss() # 损失函数

# 关键步骤1: 将模型设置为评估模式
model.eval() # 这会影响 Dropout 和 BatchNorm 等层的行为

total_loss = 0
num_samples = 0

# 关键步骤2: 使用 torch.no_grad() 包裹整个验证循环
with torch.no_grad():
    for inputs, labels in valid_dataloader:
        # inputs, labels = inputs.to(device), labels.to(device) # 如果使用GPU

        outputs = model(inputs)  # 前向传播
        # 此时 outputs.requires_grad 会是 False

        loss = criterion(outputs, labels) # 计算损失 (对于纯推理,此步骤可选)
        # 此时 loss.requires_grad 会是 False

        # 使用.item() 获取Python标量值,避免累积张量历史
        # 即使在 no_grad() 中 loss.requires_grad 已为 False,
        #.item() 仍然是好习惯,因为它更节省内存且是训练中累积loss的标准做法。
        total_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)

avg_loss = total_loss / num_samples
print(f"Validation - Average Loss: {avg_loss:.4f}")

这里有两个关键点:

  1. model.eval():在评估循环开始前调用。这会通知模型中所有模块(特别是如 DropoutBatchNorm 这类在训练和评估时行为不同的模块)切换到评估模式 。例如,Dropout 层在评估模式下会被禁用,BatchNorm 层会使用其在训练期间学习到的运行均值和方差,而不是当前批次的统计数据。
  2. with torch.no_grad()::包裹整个循环体。这确保了在前向传播和损失计算过程中不会构建计算图,从而节省显存并可能加速计算。

关于 loss.item() 的使用:在此 no_grad 上下文中,outputs = model(inputs) 产生的 outputsrequires_grad 属性为 False。因此,后续计算得到的 loss = criterion(outputs, labels)requires_grad 属性也将为 False。从理论上讲,直接累加 loss 张量(total_loss += loss)在此特定情境下不会累积计算图历史,因为 loss 本身就不带历史。然而,使用 loss.item() 将一个零维张量转换为Python标量数字,是一种更通用的良好实践。这不仅因为Python数字比零维张量对象更轻量,更重要的是,这种做法源于训练循环中的经验:在训练时,loss 张量会带有计算历史(requires_grad=True),如果不使用 .item().detach(),累积 loss 会导致显存泄漏 。因此,在验证循环中也使用 .item() 是一种稳健的习惯,确保了代码的一致性和对显存最友好的处理方式。这种做法体现了PyTorch编程的一种“卫生习惯”,即将 model.eval()torch.no_grad().item() 结合使用,以确保评估代码的正确性、高效性和稳健性。

C. 使用 @torch.no_grad() 作为装饰器

对于专门用于推理或评估的辅助函数,可以使用 @torch.no_grad() 装饰器,使整个函数体都在无梯度上下文中执行。

import torch
import torch.nn as nn

# 假设模型已定义并加载权重
# model = MyModel().to(device)
# model.eval() # 别忘了在调用前设置评估模式

# --------- 实际演示代码 ---------
model = nn.Linear(10, 1) # 虚拟模型
model.eval() # 设置为评估模式

@torch.no_grad()  # 装饰器将 no_grad 应用于整个函数
def predict_fn(input_tensor, model_instance):
    # input_tensor = input_tensor.to(device) # 如果使用GPU
    output = model_instance(input_tensor)
    # 此时 output.requires_grad 会是 False

    # 在这里可以进行任何后处理操作
    # 例如:probabilities = torch.softmax(output, dim=1)
    return output # return probabilities

# 示例用法
sample_input = torch.randn(1, 10)
prediction = predict_fn(sample_input, model)
print(f"Prediction: {prediction}")
print(f"Prediction requires_grad: {prediction.requires_grad}")

这种方式非常适合封装独立的预测逻辑,使得代码更简洁。当调用被 @torch.no_grad() 装饰的函数时,其内部所有PyTorch运算都不会追踪梯度 。

V. 精通 torch.no_grad() 进行推理:有效节省GPU显存

在进行模型推理时,合理使用 torch.no_grad() 是节省GPU显存的关键策略。

A. 何处添加 torch.no_grad()

指导原则:将 torch.no_grad() 应用于代码中任何执行模型推理或评估且不需要计算梯度的部分。

典型场景包括

  • 完整的验证或测试循环:如IV.B节示例所示,迭代处理 DataLoader 中的数据,进行前向传播并计算评估指标。
  • 专用的预测函数:如IV.C节示例所示,接收输入数据并返回模型输出的函数。
  • 模型输出的后处理步骤:如果这些后处理步骤涉及到PyTorch张量运算,并且这些运算在默认情况下可能会追踪梯度(例如,如果输入到后处理的张量 requires_grad=True),那么也应该将它们包含在 no_grad 上下文中。

目标是阻止Autograd引擎为推理流程的任何部分存储中间激活或构建计算历史。

B. 添加一次就足够了吗?理解作用域

torch.no_grad() 的作用效果严格限于其定义的作用域(scope)。它是一个上下文管理器或装饰器,其影响不会泄露到其定义的代码块之外 。

  • 对于单一、连贯的推理阶段:如果整个推理逻辑(例如,遍历数据加载器、获取模型输出、计算指标)都包含在一个 with torch.no_grad(): 块内,或者一个由 @torch.no_grad() 装饰的函数内,那么是的,将其应用于该整个作用域一次就足够了。该作用域内的所有PyTorch操作都将受益。

    # 假设 model 和 test_loader 已定义
    model.eval()
    predictions_list =
    # 整个推理过程使用一个 no_grad 块
    with torch.no_grad():
        for data_batch in test_loader:
            # data_batch = data_batch.to(device)
            outputs = model(data_batch)
            # outputs.requires_grad 为 False
            # processed_outputs = some_pytorch_postprocessing(outputs) # 同样不追踪梯度
            # predictions_list.append(processed_outputs.cpu())
    # 从模型前向传播到后处理的所有操作都被覆盖。
    
  • 对于分段的推理逻辑:如果推理过程被分解为多个独立的Python函数或代码块,并且它们是独立执行的(即没有嵌套在单个 no_grad 上下文中),那么每个执行PyTorch操作并希望禁用梯度的段落,都需要其自身的 torch.no_grad() 包装。

    def get_embeddings(model, data_batch):
        # 这个函数可能从不同地方调用,
        # 有些调用可能需要梯度,有些则不需要。
        # data_batch = data_batch.to(device)
        return model.feature_extractor(data_batch)
    
    #... 在后续的特定于推理的代码中...
    model.eval()
    # inference_sample = inference_sample.to(device)
    # 此处需要 no_grad 来进行本次特定的推理调用
    with torch.no_grad():
        raw_features = get_embeddings(model, inference_sample)
        # raw_features.requires_grad 为 False
        final_prediction = model.classifier(raw_features)
        # final_prediction.requires_grad 为 False
    

可以将其类比为房间的电灯开关。当你按下开关(进入 no_grad 上下文),房间内(作用域内)的灯(梯度追踪)就关闭了。当你离开房间(退出上下文),外面的灯不受影响。如果你进入另一个房间(另一个代码块),如果想让那里的灯也关闭,就需要操作那个房间的开关。

许多用户可能错误地认为 torch.no_grad() 会以某种方式永久性地“标记”张量,或者改变PyTorch的全局状态。澄清其作用域的限制性至关重要,这有助于防止此类误解,并帮助调试意外的梯度追踪行为。torch.no_grad() 并不会永久改变输入张量 x 本身的 requires_grad 状态;它只改变在上下文中对 x 进行操作所产生的输出的梯度追踪行为。这表明 torch.no_grad() 是对一段代码的临时状态改变,而不是对张量或全局Autograd引擎状态的永久修改 。

C. 可视化影响 (概念层面)

若不使用 torch.no_grad(),推理过程的显存占用图会显示出因存储激活值而产生的峰值,类似于训练过程中的显存图 。这些激活值是为潜在的反向传播而保留的。

当使用 torch.no_grad() 时,这些与激活值相关的显存分配大部分会消失,从而在前向传播过程中形成一个更平坦、更低的显存占用曲线 。此时,主要的显存消耗将来自于模型参数本身以及当前处理批次的输入和输出张量。

VI. 注意!常见陷阱与误解

虽然 torch.no_grad() 功能强大且易于使用,但在实践中仍存在一些常见的陷阱和误解,可能导致其被误用或效果不佳。

A. torch.no_grad() vs. model.eval():不可互换!

这是初学者乃至一些有经验用户非常容易混淆的一点 。两者都常用于“评估阶段”,这造成了语义上的重叠。然而,它们的作用目标和机制完全不同。

  • model.eval():

    • 目的:将模型设置为评估模式。这主要影响模型中特定层的行为,如 torch.nn.Dropout(在评估时禁用 dropout)和 torch.nn.BatchNorm2d(在评估时使用训练阶段学习到的运行均值和方差,而不是当前批次的统计数据)。
    • 影响:改变特定层的行为,使其适合推理或测试。它本身并禁用梯度计算。
  • with torch.no_grad()::

    • 目的:禁用Autograd引擎的梯度计算和历史追踪功能 。
    • 影响:通过不构建用于反向传播的计算图,节省显存并可能加速计算。
  • 为何推理时两者都需要

    • 使用 model.eval() 来确保模型的各层(如Dropout, BatchNorm)表现出正确的、确定性的推理行为。
    • 使用 with torch.no_grad(): 来避免不必要的梯度计算和显存消耗。

下表清晰地总结了 model.eval()with torch.no_grad(): 的区别:

特性 model.eval() with torch.no_grad():
主要目的 将模块设置为评估模式。 禁用Autograd的梯度计算。
影响对象 Dropout, BatchNorm等特定层的行为。 Autograd引擎;新操作输出的 requires_grad 属性。
梯度计算 不直接禁用梯度计算。 禁用梯度计算和历史追踪。
显存节省 间接节省(如果层在评估模式下行为更简单)。 直接且显著地节省,因不存储用于梯度的激活值。
速度提升 可能(如果层在评估模式下计算更少)。 直接地提升,因减少了Autograd的开销。
典型用法 在任何评估/推理阶段开始前调用一次。 包裹评估/推理代码块,或装饰相关函数。

这种混淆的根源在于两者都用于“评估阶段”。model.eval() 针对的是模型的行为状态(确保层的行为是确定性的),而 torch.no_grad() 针对的是Autograd引擎的操作状态(是否追踪梯度)。理解这一本质区别至关重要。

B. 训练过程中的意外使用

如果在训练循环中本应计算梯度的部分(例如,计算损失并调用 loss.backward() 之前的模型前向传播)错误地被 torch.no_grad() 包裹,那么这些操作的梯度将不会被计算 。

# 错误的训练步骤示例:
# optimizer.zero_grad()
# with torch.no_grad(): # 错误!不应在计算训练损失时使用
#     outputs = model(train_inputs)
#     loss = criterion(outputs, train_labels)
# loss.backward() # 这可能会报错,或者梯度为None,无法有效更新参数
# optimizer.step()

这将导致模型权重无法正确更新,从而 фактически停止受影响部分的学习过程。在强化学习等领域,这种错误尤为常见,因为数据收集和策略评估的界限有时比较模糊 。

C. 与 tensor.detach() 的混淆

tensor.detach()torch.no_grad() 都与计算图和梯度有关,但它们的作用方式和范围不同。

  • tensor.detach():

    • 是一个应用于特定张量的方法 。
    • 它返回一个新的张量,这个新张量与原始张量共享底层数据存储,但从当前的计算图中被“分离”出来。这意味着通过这个新张量进行的操作不会影响原始计算图的梯度传播,其 requires_grad 属性为 False
    • 原始张量及其计算历史(如果存在其他引用)保持不变。
    • 常用于需要将某个中间结果用于非微分操作,或者希望阻止梯度流向网络特定部分等复杂训练场景 。例如,在某些生成对抗网络(GANs)或强化学习算法中,可能需要阻止梯度通过网络的某一部分。
    • 会全局禁用其后代码块中的梯度追踪。
    • .detach() 本身可以减少特定张量历史所占用的内存,因为它创建了一个不带历史的新视图 。
  • with torch.no_grad()::

    • 是一个上下文管理器(或装饰器),影响其作用域内所有的PyTorch操作 。
    • 它阻止为作用域内新执行的操作创建计算图历史。
  • 推理时显存节省的关键区别:对于整个推理过程的显存优化,torch.no_grad() 通常是更直接和全面的选择。因为它全局性地(在其作用域内)禁用了图的构建。如果试图用 detach() 达到同样的目的,理论上需要对每个中间输出都调用 .detach(),这既繁琐又不高效 。torch.no_grad() 通过从源头上阻止计算历史的累积,从而更有效地管理内存。

D. torch.no_grad() 并非万能的显存清理器

torch.no_grad() 的作用是预防性的:它阻止Autograd引擎追踪操作并为梯度计算存储的中间激活值 。

它并不能

  • 释放已经被现有张量占用的显存。
  • 清除PyTorch的CUDA缓存。

要显式释放不再需要的张量所占用的显存,你需要:

  1. 删除对这些张量的所有Python引用(例如,使用 del tensor_reference)。
  2. 可以调用 gc.collect() 来尝试触发Python的垃圾回收机制 。

要释放PyTorch缓存分配器持有的但当前未被张量使用的显存(使其在 nvidia-smi 等工具中显示为空闲),应使用 torch.cuda.empty_cache()

因此,torch.no_grad() 是关于阻止未来为计算图历史分配显存,而不是清理过去的分配。

E. 工厂函数例外

官方文档明确指出了一个例外情况:“所有工厂函数,或创建新张量并接受 requires_grad 关键字参数的函数,将不受此模式影响” 。

这意味着,如果你在 no_grad 块内部通过这类函数(如 torch.nn.Parameter()torch.ones(..., requires_grad=True))显式创建一个 requires_grad=True 的张量,那么该张量确实会具有 requires_grad=True

import torch
import torch.nn as nn

with torch.no_grad():
    # torch.nn.Parameter 默认设置 requires_grad=True
    param = nn.Parameter(torch.randn(5))
    print(f"param.requires_grad inside no_grad: {param.requires_grad}") # 输出: True

    # torch.zeros 默认 requires_grad=False
    a = torch.zeros(5)
    print(f"a.requires_grad inside no_grad: {a.requires_grad}") # 输出: False

    # 在工厂函数中显式设置 requires_grad=True
    b = torch.ones(5, requires_grad=True)
    print(f"b.requires_grad inside no_grad: {b.requires_grad}") # 输出: True

    # 然而,对此类张量 (如 param 或 b) 在 no_grad 上下文内进行的后续操作,
    # 其输出的 requires_grad 仍将为 False
    output_from_param = param * 2
    print(f"output_from_param.requires_grad: {output_from_param.requires_grad}") # 输出: False
    output_from_b = b + 1
    print(f"output_from_b.requires_grad: {output_from_b.requires_grad}") # 输出: False

这种行为在实践中是合理的:例如,模型参数(nn.Parameter)即使在 no_grad 上下文中被访问或参与某种形式的“元操作”,其作为可训练参数的本质(即 requires_grad=True)不应改变。但任何利用这些参数进行的、发生在 no_grad 块内的计算(如一次前向传播),其结果不应再追踪梯度。

F. torch.no_grad() 不改变张量的值

torch.no_grad() 只影响梯度追踪和计算图的构建,它不改变在其上下文中执行的任何操作的数值结果 。假设 model.eval() 已被正确调用以确保层行为的一致性,那么在 no_grad 块内进行的前向传播所产生的输出张量的值,与在块外进行同样的前向传播所产生的值是相同的。

理解这些常见陷阱对于有效调试至关重要。例如,如果一个模型没有按预期学习,检查代码中是否存在错放的 no_grad 块是一个关键的调试步骤。如果显存没有如预期那样节省下来,那么理解 no_grad 的作用域、工厂函数例外以及它与显存清理操作的区别就非常重要。特别地,用户有时可能试图通过将网络的一部分包裹在 no_grad 中来实现训练时的“冻结”,但这通常是错误的做法。如果目标是冻结某些层的参数不参与训练更新,但允许梯度流经这些层的激活值到达更早的层,正确的做法是设置这些参数的 param.requires_grad = Falseno_grad 会阻止其作用域内操作的输出产生梯度历史,这会中断梯度向更早层的传播。

VII. 总结:使用 torch.no_grad() 

torch.no_grad() 是PyTorch中一个简洁而强大的工具,它通过在特定代码区域禁用梯度计算,为模型推理和评估带来了显著的显存节省和潜在的速度提升。

核心益处回顾

  • 显存优化:通过避免存储用于反向传播的中间激活值,大幅降低推理时的峰值显存占用。
  • 性能提升:减少了Autograd引擎的开销,可能加速计算过程。

关键实践

  • 在进行任何形式的模型评估、验证或纯推理时,务必将相关代码块包裹在 with torch.no_grad(): 上下文中,或对专用推理函数使用 @torch.no_grad() 装饰器。
  • 始终将 model.eval()torch.no_grad() 结合使用,以确保模型层(如Dropout、BatchNorm)行为正确,同时实现计算效率的最大化。
  • 理解其作用域限制:torch.no_grad() 的影响仅限于其直接包裹的代码块或函数。

你可能感兴趣的:(pytorch,人工智能,python,笔记,深度学习)