在Keras中实现Multi-head-attention

多头注意力机制其实本质上就是将多个注意力结果进行拼接后输出,目前有多种拼接的方法。
第一种:拼接后乘以一个可训练矩阵进行维度转换。
在Keras中实现Multi-head-attention_第1张图片例如有32维数据,则设置8个head,每个head有32维,则最后拼接结果为32x8=256维,再设置一权值矩阵为W0为(256x32),则最后结果为
【1x256】x【256x32】=1x32
第二种方法:将每一个头的维度缩小再对每个头的结果直接拼接为最后的输出维度。
在Keras中实现Multi-head-attention_第2张图片
例如有128维数据,则设置8个head,每个head有16维,则最后拼接结果为16x8=128维。
第三种方法:对每个头的结果进行求和后求平均,此方法多用于GAT。
在Keras中实现Multi-head-attention_第3张图片
本次我们应用第一种方法来做Multi-head-attention,Attention的实现代码在我另一篇篇文章中已经实现(https://blog.csdn.net/qq_41669355/article/details/121362089),同时也借鉴了苏神在自定义层中的一个类,以实现在自定义层中调用已有的层(https://spaces.ac.cn/archives/4765)。

class MAtt(OurLayer):
    def __init__(self, out_dim, **kwargs):
        super(MAtt, self).__init__(**kwargs)
        self.out_dim = out_dim

    def build(self, input_shape):
        super(MAtt, self).build(input_shape)
        self.head1 = MyAttention(out_dim=self.out_dim)
        self.head2 = MyAttention(out_dim=self.out_dim)
        self.head3 = MyAttention(out_dim=self.out_dim)
        self.head4 = MyAttention(out_dim=self.out_dim)
        self.w0 = Dense(self.out_dim, use_bias=False)
    def call(self, inputs):
        # input_size = tf.shape(inputs)
        h1 = self.reuse(self.head1,inputs)
        h2 = self.reuse(self.head2,inputs)
        h3 = self.reuse(self.head3,inputs)
        h4 = self.reuse(self.head4,inputs)
        # h_r=tf.reshape(tf.concat([h1,h2,h3,h4],-1),(input_size[0],input_size[1],1,input_size[-1]*4))
        # h_r=tf.reshape(tf.multiply(h_r,self.w0),(input_size[0],input_size[1],input_size[-1]))
        h_r=self.reuse(self.w0,tf.concat([h1, h2, h3, h4], -1))
        # h_r=average([h1,h2,h3,h4])
        return h_r
    def compute_output_shape(self, input_shape):
        return (input_shape[0],input_shape[1], self.out_dim)

这个代码目前还是有缺陷,比如不能自定义head数,目前固定为4,将该代码放置Keras的IMBD任务进行测试,结果如下:
在Keras中实现Multi-head-attention_第4张图片
最后的结果应该过拟合了,或许还有些参数没调好(比如key_size),但懒得调了,以后有需要的话再优化叭。

你可能感兴趣的:(神经网络,算法,矩阵,python)