26 - UFO模块

论文《UFO-ViT: High Performance Linear Vision Transformer without Softmax》

1、作用

UFO-ViT旨在解决传统Transformer在视觉任务中所面临的主要挑战之一:SA机制的计算资源需求随输入尺寸的平方增长,这使得处理高分辨率输入变得不切实际。UFO-ViT通过提出一种新的SA机制,消除了非线性操作,实现了对计算复杂度的线性控制,同时保持了高性能。

2、机制

1、自注意力机制的简化

UFO-ViT通过消除softmax非线性,简化了自注意力机制。它采用简单的L2范数替代softmax函数,利用矩阵乘法的结合律先计算键和值的乘积,然后再与查询相乘。

2、跨标准化(XNorm)

UFO-ViT引入了一种新的规范化方法——跨标准化(XNorm),用于替换softmax。这种方法将键和值的自注意力乘积直接相乘,并通过线性核方法生成聚类,以线性复杂度处理自注意力。

3、独特优势

1、线性复杂度

与传统的具有O(N^2)复杂度的SA机制不同,UFO-ViT的SA机制具有线性复杂度,使其能够有效处理高分辨率输入。

2、适用于通用目的

UFO-ViT在图像分类和密集预测任务中的表现证明了其作为通用模型的潜力。尽管复杂度线性,但UFO-ViT模型在较低容量和FLOPS下仍然超越了大多数现有的基于Transformer的模型。

3、推理速度快,GPU内存需求低

UFO-ViT模型具有更快的推理速度,并且在各种分辨率下所需的GPU内存较少。特别是在处理高分辨率输入时,所需的计算资源并没有显著增加。

4、代码

import numpy as np
import torch
from torch import nn
from torch.functional import norm
from torch.nn import init

# 定义XNorm函数,对输入x进行规范化
def XNorm(x, gamma):
    norm_tensor = torch.norm(x, 2, -1, True)
    return x * gamma / norm_tensor

# UFOAttention类继承自nn.Module
class UFOAttention(nn.Module):
    '''
    实现一个改进的自注意力机制,具有线性复杂度。
    '''

    # 初始化函数
    def __init__(self, d_model, d_k, d_v, h, dropout=.1):
        '''
        :param d_model: 模型的维度
        :param d_k: 查询和键的维度
        :param d_v: 值的维度
        :param h: 注意力头数
        '''
        super(UFOAttention, self).__init__()
        # 初始化四个线性层:为查询、键、值和输出转换使用
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        # gamma参数用于规范化
        self.gamma = nn.Parameter(torch.randn((1, h, 1, 1)))

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    # 权重初始化
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    # 前向传播
    def forward(self, queries, keys, values):
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        # 通过线性层将查询、键、值映射到新的空间
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)

        # 计算键和值的乘积,然后对结果进行规范化
        kv = torch.matmul(k, v)  # bs,h,c,c
        kv_norm = XNorm(kv, self.gamma)  # bs,h,c,c
        q_norm = XNorm(q, self.gamma)  # bs,h,n,c
        out = torch.matmul(q_norm, kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)

        return out

if __name__ == '__main__':
    # 示例用法
    block = UFOAttention(d_model=512, d_k=512, d_v=512, h=8).cuda()
    input = torch.rand(64, 64, 512).cuda()
    output = block(input, input, input)
    print(output.shape)

你可能感兴趣的:(深度学习算法实现,人工智能,深度学习,pytorch,计算机视觉)