torch.nn.functional.log_softmax
torch.nn.functional.log_softmax
是 PyTorch 提供的用于计算 log(softmax)
的函数,通常用于 多分类任务 和 计算交叉熵损失,可以提高数值稳定性并防止数值溢出。
log_softmax
的数学公式对于 输入张量 X X X,softmax 计算如下:
softmax ( X i ) = e X i ∑ j e X j \text{softmax}(X_i) = \frac{e^{X_i}}{\sum_{j} e^{X_j}} softmax(Xi)=∑jeXjeXi
然后取对数:
log softmax ( X i ) = X i − log ∑ j e X j \log \text{softmax}(X_i) = X_i - \log \sum_{j} e^{X_j} logsoftmax(Xi)=Xi−logj∑eXj
log_softmax
而不是 softmax + log
?softmax(X)
可能会导致指数运算 溢出(特别是 X
取值较大时)。log_softmax(X)
计算时,先 进行数值归一化,不会导致溢出。log_softmax
计算速度更快,因为它可以与 nll_loss
直接配合使用。torch.nn.functional.log_softmax
语法torch.nn.functional.log_softmax(input, dim)
参数 | 说明 |
---|---|
input |
输入张量(通常是模型 logits 输出) |
dim |
计算 softmax 的维度(通常 dim=1 ,沿类别维度计算) |
log_softmax
import torch
import torch.nn.functional as F
# 假设 batch_size=2,类别数=3
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 1.5]])
# 计算 log_softmax(dim=1 表示沿类别维度)
log_probs = F.log_softmax(logits, dim=1)
print(log_probs)
输出
tensor([[-0.4170, -1.4170, -2.3170],
[-2.0076, -0.5076, -1.0076]])
解析
logits
未归一化,直接输入。log_softmax(logits, dim=1)
对每一行计算 log(softmax)。softmax
溢出问题。log_softmax
与 softmax + log
的区别logits = torch.tensor([[2.0, 1.0, 0.1]])
# 方式 1:直接使用 log_softmax
log_probs1 = F.log_softmax(logits, dim=1)
# 方式 2:先 softmax 再 log
softmax_probs = torch.softmax(logits, dim=1)
log_probs2 = torch.log(softmax_probs)
print(log_probs1)
print(log_probs2)
两者输出相同,但 log_softmax
更稳定!
tensor([[-0.4170, -1.4170, -2.3170]])
tensor([[-0.4170, -1.4170, -2.3170]])
F.log_softmax
比 torch.softmax + torch.log
更稳定,计算更高效。log_softmax
在交叉熵损失中的作用PyTorch F.cross_entropy
内部已经包含 log_softmax
:
import torch.nn.functional as F
# 假设 batch_size=2,类别数=3
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 1.5]])
targets = torch.tensor([0, 2]) # 真实类别索引
# 交叉熵损失(内部自动使用 log_softmax)
loss = F.cross_entropy(logits, targets)
print(loss)
解析
F.cross_entropy(logits, targets)
内部使用 log_softmax + nll_loss
。logits
,无需手动计算 softmax
。log_softmax
与 nll_loss
如果已经计算了 log_softmax
,可以直接使用 nll_loss
:
log_probs = F.log_softmax(logits, dim=1)
# 使用 nll_loss 计算损失
loss = F.nll_loss(log_probs, targets)
print(loss)
nll_loss
只计算 log_softmax
结果与目标的匹配度,与 F.cross_entropy
结果相同。
F.cross_entropy
内部已使用 log_softmax
,无需手动计算。log_softmax
。log_softmax
。torch.nn.functional.log_softmax
用于计算 log(softmax)
,提高数值稳定性。softmax + log
更稳定,更高效。cross_entropy
已包含 log_softmax
,无需额外计算。log_softmax
也是常用的概率计算方法。