Arxiv 2106 - CAT: Cross Attention in Vision Transformer

(Figure 1)

  • Paper: https://arxiv.org/abs/2106.05786
  • Code: https://github.com/linhezheng19/CAT
  • Detailed walkthrough (in Chinese): https://mp.weixin.qq.com/s/VJCDAo94Uo_OtflSHRc1AQ
  • Core motivation: replace full global attention with attention inside each patch plus attention across patches, which greatly simplifies the global attention computation (a rough complexity sketch follows this list).
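
For an H×W feature map with C channels and patch size p, so that there are N = HW/p² patches, counting only the Q·Kᵀ attention-map cost gives the following rough scaling comparison (my own back-of-the-envelope estimate based on the attention-map sizes described below, not the paper's exact complexity formulas):

$$
\Omega(\text{global MSA}) \sim (HW)^2 C, \qquad
\Omega(\text{IPSA}) \sim N \, (p^2)^2 C = HW \, p^2 C, \qquad
\Omega(\text{CPSA}) \sim N^2 \, p^2 C = \frac{(HW)^2 C}{p^2}
$$

IPSA therefore grows only linearly with HW, while CPSA reduces the quadratic term of global attention by a factor of p².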

This post only gives a rough description of the core modules, aiming to present the key differences of this work; for the full details, please refer to the walkthrough article linked above.

Main Content

(Figures 2–4)

  • Cross Attention Block (CAB) = Inner-Patch Self-Attention Block (IPSA) + Cross-Patch Self-Attention Block (CPSA); a shape-level sketch of both follows this list:
    • IPSA: standard patch-local self-attention. The attention input is a tensor of shape B*nph*npw, ph*pw, C, which yields one ph*pw × ph*pw attention matrix per patch. This module models the global relations among the pixels inside each patch.
    • CPSA: handled differently from most earlier variants. The input to the attention computation is reshaped to B*C, nph*npw, ph*pw, so the resulting attention matrix has size nph*npw × nph*npw. Here the spatial values of a single channel inside each patch act as that patch's token representation, and this ph*pw dimension is absorbed by the similarity computation. Because it operates on each channel independently, the module provides a lightweight form of global information exchange across patches.
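
A minimal shape-level sketch of the two attention patterns (my own illustration with plain torch.matmul on toy sizes, omitting the linear projections, multi-head split, and relative position bias used in the real implementation):

    import torch

    B, C, H, W, p = 2, 96, 56, 56, 7          # batch, channels, feature map size, patch size
    nP = (H // p) * (W // p)                  # number of patches

    x = torch.randn(B, nP, p * p, C)          # features grouped per patch

    # IPSA: tokens = the p*p pixels of one patch, token dimension = C
    ipsa_in = x.reshape(B * nP, p * p, C)                            # (B*nP, p*p, C)
    ipsa_attn = torch.softmax(ipsa_in @ ipsa_in.transpose(-2, -1) / C ** 0.5, dim=-1)
    print(ipsa_attn.shape)                                           # (B*nP, p*p, p*p)

    # CPSA: tokens = the nP patches, token dimension = the p*p values of one channel
    cpsa_in = x.permute(0, 3, 1, 2).reshape(B * C, nP, p * p)        # (B*C, nP, p*p)
    cpsa_attn = torch.softmax(cpsa_in @ cpsa_in.transpose(-2, -1) / (p * p) ** 0.5, dim=-1)
    print(cpsa_attn.shape)                                           # (B*C, nP, nP)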

Core Code

        # Excerpt from the block's forward pass; on entry x has shape (B, H*W, C).
        x = x.view(B, H, W, C)

        # partition the feature map into non-overlapping patches
        patches = partition(x, self.patch_size)  # nP*B, patch_size, patch_size, C
        patches = patches.view(-1, self.patch_size * self.patch_size, C)  # nP*B, patch_size*patch_size, C

        # IPSA or CPSA
        if self.attn_type == "ipsa":
            # tokens are the pixels of one patch, each of dimension C
            attn = self.attn(patches)  # nP*B, patch_size*patch_size, C
        elif self.attn_type == "cpsa":
            # tokens are the patches themselves, handled one channel at a time (dimension patch_size*patch_size)
            patches = patches.view(B, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2, C).permute(0, 3, 1, 2).contiguous()
            patches = patches.view(-1, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2)  # B*C, nP, patch_size*patch_size
            attn = self.attn(patches).view(B, C, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2)
            attn = attn.permute(0, 2, 3, 1).contiguous().view(-1, self.patch_size ** 2, C)  # nP*B, patch_size*patch_size, C
        else:
            raise NotImplementedError(f"Unknown attention type: {self.attn_type}")

        # reverse operation of partition
        attn = attn.view(-1, self.patch_size, self.patch_size, C)
        x = reverse(attn, self.patch_size, H, W)  # B, H, W, C

        x = x.view(B, H * W, C)
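
The excerpt relies on two helpers from the repository, partition and reverse. Below is a minimal sketch of what they are expected to do, inferred from the shape comments above (the repository's own implementations may differ in details); it is essentially the window partition/reverse pattern familiar from Swin Transformer:

    import torch

    def partition(x: torch.Tensor, patch_size: int) -> torch.Tensor:
        """(B, H, W, C) -> (nP*B, patch_size, patch_size, C), with nP = (H//p) * (W//p)."""
        B, H, W, C = x.shape
        x = x.view(B, H // patch_size, patch_size, W // patch_size, patch_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, patch_size, patch_size, C)

    def reverse(patches: torch.Tensor, patch_size: int, H: int, W: int) -> torch.Tensor:
        """(nP*B, patch_size, patch_size, C) -> (B, H, W, C)."""
        B = patches.shape[0] // ((H // patch_size) * (W // patch_size))
        x = patches.view(B, H // patch_size, W // patch_size, patch_size, patch_size, -1)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)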

Experimental Results

(Figure 5)
