torch.nn.functional.cross_entropy
torch.nn.functional.cross_entropy
是 PyTorch 中用于分类任务的交叉熵损失函数,用于衡量 预测概率分布与真实类别分布之间的差异,常用于 多分类任务(multi-class classification)。
对于 单个样本,交叉熵损失的计算公式为:
L = − ∑ i = 1 C y i log ( y i ^ ) \mathcal{L} = -\sum_{i=1}^{C} y_i \log (\hat{y_i}) L=−i=1∑Cyilog(yi^)
其中:
在 PyTorch 中,cross_entropy
直接接受未经过 Softmax 变换的 logits,并且内部会 自动计算 Softmax 并进行对数计算,提高数值稳定性。
torch.nn.functional.cross_entropy
的语法torch.nn.functional.cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean')
参数 | 说明 |
---|---|
input |
模型的 logits 输出(未经过 softmax) |
target |
真实类别索引(不是 one-hot 编码) |
weight |
每个类别的损失权重(用于类别不平衡问题) |
ignore_index |
忽略的类别索引(通常用于 padding ) |
reduction |
mean (默认):取均值,sum :取总和,none :逐样本损失 |
import torch
import torch.nn.functional as F
# 假设有 3 个类别
logits = torch.tensor([[2.0, 0.5, 1.0], [0.5, 2.0, 1.5]]) # 模型输出(未经过 Softmax)
targets = torch.tensor([0, 2]) # 真实类别索引
# 计算交叉熵损失
loss = F.cross_entropy(logits, targets)
print(loss)
解析
logits
形状为 (batch_size, num_classes)
。targets
形状为 (batch_size,)
,包含每个样本的真实类别索引。cross_entropy
vs nll_loss
PyTorch 提供 nll_loss
(负对数似然损失),但一般不直接使用,而是搭配 log_softmax
:
logits = torch.tensor([[2.0, 0.5, 1.0], [0.5, 2.0, 1.5]])
log_probs = torch.log_softmax(logits, dim=1) # 先计算 log_softmax
targets = torch.tensor([0, 2])
# 使用 nll_loss(需要 log softmax 作为输入)
loss_nll = F.nll_loss(log_probs, targets)
# cross_entropy = softmax + log + nll_loss
loss_ce = F.cross_entropy(logits, targets)
print(loss_nll, loss_ce) # 两者结果相同
结论
F.cross_entropy(logits, target)
= softmax + log + nll_loss
cross_entropy
计算更稳定,建议优先使用。如果类别不均衡,可以使用 weight
参数:
class_weights = torch.tensor([1.0, 2.0, 3.0]) # 权重:类别 0 最小,类别 2 最大
loss = F.cross_entropy(logits, targets, weight=class_weights)
print(loss)
作用
ignore_index
用于 padding
在 NLP 任务(如序列标注)中,可使用 ignore_index
忽略 padding
:
targets = torch.tensor([0, 2, -1]) # -1 代表 padding
loss = F.cross_entropy(logits, targets, ignore_index=-1)
print(loss)
作用
padding
位置的损失,适用于 RNN、Transformer 等 NLP 任务。reduction
参数reduction
控制损失计算方式:
loss_none = F.cross_entropy(logits, targets, reduction="none") # 返回每个样本的损失
loss_sum = F.cross_entropy(logits, targets, reduction="sum") # 总和
loss_mean = F.cross_entropy(logits, targets, reduction="mean") # 默认均值
print(loss_none, loss_sum, loss_mean)
作用
none
:返回逐样本损失,适用于 需要自定义损失计算 的任务。sum
:计算总损失。mean
:计算均值(默认)。padding
)weight
选项)torch.nn.functional.cross_entropy
是 PyTorch 中最常用的分类损失函数。softmax
。input
为 logits,target
为类别索引)。weight
处理类别不均衡,使用 ignore_index
处理 padding
。在 PyTorch 分类任务中,推荐使用 F.cross_entropy
作为标准损失函数。