Seq2Seq - CrossEntropyLoss细节讨论

在 PyTorch 中,损失函数 CrossEntropyLoss 的输入参数通常需要满足特定的形状要求。对于 CrossEntropyLoss,输入参数的形状要求如下:

  1. input:模型的输出,形状为 [N, C],其中:

    • N 是样本数量(或展平后的序列长度)。

    • C 是类别数量(目标词汇表的大小)。

  2. target:目标标签,形状为 [N],其中每个元素是一个类别索引(整数)。

在上一节的代码中:

loss = loss_fn(similarities.view(-1, len(cn_vocab)), batch_labels.view(-1))

similaritiesbatch_labels 需要被调整为上述形状,以便符合 CrossEntropyLoss 的输入要求。以下详细解释为什么要这样写:

1. similarities.view(-1, len(cn_vocab))

  • similarities 的原始形状[batch_size, seq_len, output_dim],其中:

    • batch_size 是批次大小。

    • seq_len 是序列长度。

    • output_dim 是目标词汇表的大小(len(cn_vocab))。

  • 目标形状[N, C],其中:

    • N 是展平后的序列长度,即 batch_size * seq_len

    • C 是目标词汇表的大小,即 len(cn_vocab)

  • 操作

    • 使用 view(-1, len(cn_vocab))similarities 展平为二维张量,形状为 [batch_size * seq_len, len(cn_vocab)]

    • 这样,每个时间步的输出都被展平为一个二维张量,每一行表示一个时间步的预测概率分布。

2. batch_labels.view(-1)

  • batch_labels 的原始形状[batch_size, seq_len],其中:

    • batch_size 是批次大小。

    • seq_len 是序列长度。

  • 目标形状[N],其中:

    • N 是展平后的序列长度,即 batch_size * seq_len

  • 操作

    • 使用 view(-1)batch_labels 展平为一维张量,形状为 [batch_size * seq_len]

    • 这样,每个时间步的目标标签都被展平为一个一维张量,每个元素是一个类别索引。

3. 为什么这样写

  • 符合 CrossEntropyLoss 的输入要求

    • CrossEntropyLoss 要求输入的预测概率分布是一个二维张量 [N, C],其中每一行表示一个样本的预测概率分布。

    • 目标标签是一个一维张量 [N],其中每个元素是一个类别索引。

  • 处理序列数据

    • 在序列到序列的任务中,每个时间步都有一个预测和一个目标标签。

    • 通过展平操作,可以将所有时间步的预测和目标标签合并为一个批次,从而一次性计算整个批次的损失。

示例

假设:

  • batch_size = 2

  • seq_len = 3

  • output_dim = 5(目标词汇表大小)

原始数据:

similarities: [2, 3, 5]  # [batch_size, seq_len, output_dim]
batch_labels: [2, 3]     # [batch_size, seq_len]

经过 view 操作后:

similarities.view(-1, 5): [6, 5]  # [batch_size * seq_len, output_dim]
batch_labels.view(-1): [6]       # [batch_size * seq_len]

这样,similarities 的每一行表示一个时间步的预测概率分布,batch_labels 的每个元素是一个类别索引,完全符合 CrossEntropyLoss 的输入要求。

你可能感兴趣的:(人工智能,深度学习,seq2seq)