对于Transformer中的mask理解

概述

transformer中的mask操作,可以分成encoder端和decoder端

  1. Enocoder中的mask
    对于Transformer中的mask理解_第1张图片
    首先确定mask的形状是 batch_size*seq_length
    对于小于最大长度的句子进行补0操作,对于大于最大长度的句子进行截断操作。
    对于Transformer中的mask理解_第2张图片
    虽然对少于最大长度的句子进行了补零操作但是这些0仍然会参与注意力分数的计算。
    对于Transformer中的mask理解_第3张图片
    这里需要将mask中的0变成负无穷,1变成0,与计算的注意力矩阵相加,原来有单词的注意力不变,没有单词的位置变换成负无穷,之后在进行softmax运算
  2. decoder端的mask
    对于Transformer中的mask理解_第4张图片
    decoder的mask形状是一个下三角矩阵,解码器在翻译单词时只能看到前面已经翻译的单词,不能看到后面的答案,所以使用一个下三角矩阵进行遮盖,每一次解码只给解码器看前面的单词。

你可能感兴趣的:(nlp,自然语言处理)