【PyTorch】torch.nn.functional.log_softmax() 函数:计算 log(softmax),用于多分类任务

torch.nn.functional.log_softmax

torch.nn.functional.log_softmaxPyTorch 提供的用于计算 log(softmax) 的函数,通常用于 多分类任务计算交叉熵损失,可以提高数值稳定性并防止数值溢出。


1. log_softmax 的数学公式

对于 输入张量 X X X,softmax 计算如下:
softmax ( X i ) = e X i ∑ j e X j \text{softmax}(X_i) = \frac{e^{X_i}}{\sum_{j} e^{X_j}} softmax(Xi)=jeXjeXi
然后取对数:
log ⁡ softmax ( X i ) = X i − log ⁡ ∑ j e X j \log \text{softmax}(X_i) = X_i - \log \sum_{j} e^{X_j} logsoftmax(Xi)=XilogjeXj

为什么使用 log_softmax 而不是 softmax + log

  1. 防止数值溢出
    • softmax(X) 可能会导致指数运算 溢出(特别是 X 取值较大时)。
    • log_softmax(X) 计算时,先 进行数值归一化,不会导致溢出。
  2. 更高效
    • log_softmax 计算速度更快,因为它可以与 nll_loss 直接配合使用。

2. torch.nn.functional.log_softmax 语法

torch.nn.functional.log_softmax(input, dim)
参数 说明
input 输入张量(通常是模型 logits 输出)
dim 计算 softmax 的维度(通常 dim=1,沿类别维度计算)

3. 示例:计算 log_softmax

import torch
import torch.nn.functional as F

# 假设 batch_size=2,类别数=3
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 1.5]])

# 计算 log_softmax(dim=1 表示沿类别维度)
log_probs = F.log_softmax(logits, dim=1)

print(log_probs)

输出

tensor([[-0.4170, -1.4170, -2.3170],
        [-2.0076, -0.5076, -1.0076]])

解析

  • logits 未归一化,直接输入。
  • log_softmax(logits, dim=1) 对每一行计算 log(softmax)
  • 数值稳定性更高,避免 softmax 溢出问题

4. log_softmaxsoftmax + log 的区别

logits = torch.tensor([[2.0, 1.0, 0.1]])

# 方式 1:直接使用 log_softmax
log_probs1 = F.log_softmax(logits, dim=1)

# 方式 2:先 softmax 再 log
softmax_probs = torch.softmax(logits, dim=1)
log_probs2 = torch.log(softmax_probs)

print(log_probs1)
print(log_probs2)

两者输出相同,但 log_softmax 更稳定!

tensor([[-0.4170, -1.4170, -2.3170]])
tensor([[-0.4170, -1.4170, -2.3170]])
  • F.log_softmax torch.softmax + torch.log 更稳定,计算更高效

5. log_softmax 在交叉熵损失中的作用

PyTorch F.cross_entropy 内部已经包含 log_softmax

import torch.nn.functional as F

# 假设 batch_size=2,类别数=3
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 1.5]])
targets = torch.tensor([0, 2])  # 真实类别索引

# 交叉熵损失(内部自动使用 log_softmax)
loss = F.cross_entropy(logits, targets)
print(loss)

解析

  • F.cross_entropy(logits, targets) 内部使用 log_softmax + nll_loss
  • 直接传入 logits,无需手动计算 softmax

6. log_softmaxnll_loss

如果已经计算了 log_softmax,可以直接使用 nll_loss

log_probs = F.log_softmax(logits, dim=1)

# 使用 nll_loss 计算损失
loss = F.nll_loss(log_probs, targets)

print(loss)

nll_loss 只计算 log_softmax 结果与目标的匹配度,与 F.cross_entropy 结果相同。


7. 适用场景

  • 分类任务(如 CNN, NLP)
    • F.cross_entropy 内部已使用 log_softmax,无需手动计算。
  • 强化学习(RL)
    • 计算 策略梯度损失 时,经常需要 log_softmax
  • 变分自编码器(VAE)
    • 计算 ELBO 损失 需要 log_softmax

8. 结论

  • torch.nn.functional.log_softmax 用于计算 log(softmax),提高数值稳定性
  • softmax + log 更稳定,更高效
  • 分类任务 中,cross_entropy 已包含 log_softmax,无需额外计算
  • 在强化学习、VAE 等任务中,log_softmax 也是常用的概率计算方法

你可能感兴趣的:(PyTorch基础,log_softmax,多分类,交叉熵损失,分类,pytorch,python,深度学习)