基于pytorch的Swin Transformer用于cifar10分类

        在上篇文章基于pytorch的Vision Transformer用于cifar10分类中我们了解VIT在视觉分类中的应用,在本篇文章中将介绍另一种transformer模型-Swin Transformer。

一、前言:

对比ViT与Swin Transformer

1.模型架构的不同

  • ViT:Vision Transformer(ViT)是最早将Transformer模型应用于图像分类的网络。它的核心思想是将图像分割成若干固定大小的小块(patches),然后将这些patches通过线性嵌入变成一维向量,并喂给Transformer进行处理。ViT的优点在于它能够充分捕捉全局上下文信息,因为Transformer的自注意力机制能够处理全局依赖,但这种架构也要求大量数据来避免过拟合,因为它缺乏局部信息的建模。

  • Swin Transformer:Swin Transformer是为了解决ViT在计算和效率上的问题提出的。它采用了滑动窗口的自注意力机制(Shifted Window),通过局部窗口内的自注意力计算来减少计算复杂度,并通过将窗口在不同层级上移动(shift),来更好地捕捉长距离的依赖关系。Swin Transformer还采用了层次化的架构,将图像分割成不同尺寸的patches,从而能够捕捉不同尺度的特征。

2.计算复杂度不同

  • ViT:由于ViT是全局自注意力机制,它的计算复杂度是与图像的尺寸平方成正比的,这对于高分辨率图像来说会非常昂贵。具体来说,ViT的自注意力机制计算复杂度为O(N2)O(N^2)O(N2),其中NNN是输入的tokens数量,这对于大规模图像分类来说会导致较高的计算开销。

  • Swin Transformer:Swin Transformer通过局部窗口计算自注意力,将复杂度降低为O(N⋅log⁡N)O(N \cdot \log N)O(N⋅logN),因此在处理大规模图像时更为高效。它将图像分割为较小的窗口,并通过层次化处理和窗口的滑动来减少计算量,并且保持较好的性能。

3.局部与全局信息的处理

  • ViT:ViT本身能够处理全局信息,适用于捕捉图像的全局上下文,但由于它没有直接建模局部区域的信息(即patches之间的关系),这可能导致在处理一些较为细致的局部特征时性能较差。

  • Swin Transformer:Swin Transformer则通过窗口化处理和滑动窗口机制,在不同层次捕捉局部信息,并通过窗口的移动和层次化结构能够有效地结合局部和全局信息。这使得Swin Transformer在不同尺度上有更好的特征学习能力。

4. 灵活性和扩展性

  • ViT:ViT是一种相对简单的模型结构,较为直观,但它的表现和训练数据量紧密相关。ViT在大数据集上表现优秀,但在数据较少的情况下,容易出现过拟合。

  • Swin Transformer:Swin Transformer通过层次化和局部窗口的设计,使得模型在训练数据量较小的情况下也能够保持较好的性能。同时,Swin Transformer具有较好的扩展性,能够有效处理不同分辨率和大小的输入图像。

5. 在视觉任务中的表现

  • ViT:在大数据集(如ImageNet-21k)上训练时,ViT表现出了非常强的能力,但对于小数据集,ViT可能无法充分发挥其优势。

  • Swin Transformer:Swin Transformer能够在各种视觉任务中展现出强劲的性能,包括目标检测、图像分割等。其高效的计算和较强的特征提取能力使得它在多种任务中表现出色。

使用Swin Transformer模型的好处

  • 计算效率更高:由于采用了局部窗口和滑动窗口机制,Swin Transformer在计算复杂度上相较于ViT大大减少,能够在处理大规模图像时更高效。

  • 更好的局部特征建模:Swin Transformer在每个层次上通过局部窗口捕捉局部特征,并通过滑动窗口机制使得模型能够在多个尺度上学习信息,从而能够更好地建模图像中的局部细节和全局依赖。

  • 灵活性强:Swin Transformer通过层次化架构和不同大小的窗口设计,使得它能够适应不同大小、不同分辨率的图像输入,具有较强的扩展性。

  • 更好的性能:Swin Transformer在多项视觉任务中表现出色,包括图像分类、目标检测、图像分割等任务,尤其在计算资源有限的情况下表现得尤为突出。

二、Swin Transformer的工作原理的详细分析:
 

1. 分层结构

Swin Transformer采用层次化结构(Hierarchical Architecture),通过分层的方式逐步缩小特征图的尺寸,并逐层提取更高层次的抽象特征。与ViT将整个图像直接分割成固定大小的patches不同,Swin Transformer在每一层中通过改变patch的尺寸来构建多尺度的特征表示。

  • 低层次:输入图像被划分为较小的patches,每个patch都包含较小的局部信息。

  • 高层次:随着网络深度的增加,patch的尺寸逐步增大,从而捕获更大范围的特征。

2. 滑动窗口自注意力(Shifted Window Attention)

传统的Transformer使用全局自注意力(Global Attention),但计算复杂度是O(N2)O(N^2)O(N2),其中N是tokens的数量,对于高分辨率图像来说计算量巨大。Swin Transformer通过引入局部窗口滑动窗口来降低计算复杂度。

  • 局部窗口:Swin Transformer将输入图像划分为固定大小的局部窗口(例如,7x7或8x8的patch)。在每个窗口内,采用自注意力机制计算该窗口内的tokens之间的相互关系。

  • 滑动窗口:为了克服窗口之间信息隔离的问题,Swin Transformer采用了滑动窗口策略(Shifted Window)。即在每一层,窗口的划分位置会略微偏移(shift),从而在不同层次上能够捕捉到窗口之间的交互和全局依赖。

这种滑动窗口机制通过窗口的平移,有效地减少了计算复杂度,并允许信息跨窗口进行交流,避免了信息孤岛现象。

3. 局部和全局特征的结合

Swin Transformer通过分层和滑动窗口的设计,使得它能够同时处理局部信息和全局信息:

  • 在较低层次,局部窗口能够专注于细节特征,捕捉局部模式(如边缘、纹理等)。

  • 在较高层次,通过滑动窗口和平移机制,跨窗口的信息能够融合,形成全局上下文感知能力,使模型能够捕捉图像中的全局依赖。

4. Patch Merging 和 Downsampling

为了使得模型能够处理更大范围的上下文信息,Swin Transformer采用了Patch Merging操作来进行下采样。通过合并多个小patch,生成更大的patch。这一操作不仅减少了特征图的大小,还提高了表示能力。在网络的每一层,Swin Transformer通过这种方式减少计算量,同时保留了足够的特征信息。

  • Patch Merging:在每一层的输出中,多个相邻的patch会合并成一个较大的patch,并通过一个线性映射来处理。

5. Transformer Encoder

Swin Transformer在每一层的操作与标准Transformer类似,包括:

  • 多头自注意力(Multi-Head Attention):对于每个局部窗口,Swin Transformer通过多头自注意力来捕捉每个patch之间的相互依赖关系。自注意力机制计算每个位置(token)对其他位置(token)的关注权重,从而能够聚合全局信息。

  • 前馈神经网络(Feedforward Network):每个自注意力层后接一个前馈神经网络,用于进一步转换特征。

6. Transformer Layer Normalization 和残差连接

与传统的Transformer类似,Swin Transformer也采用了层归一化(Layer Normalization)和残差连接(Residual Connection)来避免梯度消失问题,同时加速收敛。

  • 残差连接:每个Transformer模块内部都包含了残差连接,这使得信息能够直接跨越多个层传递,缓解了深层网络训练中的难题。

7. 输出

通过多层堆叠的Swin Transformer模块,最终会得到一个层次化的特征表示。输出可以用于不同的任务,如图像分类、目标检测或图像分割等。对于图像分类任务,最后会将特征传递到一个分类头进行预测。

总结

Swin Transformer的工作原理结合了Transformer的全局建模能力和局部感知能力。通过分层结构、局部窗口自注意力、滑动窗口机制和Patch Merging,它能够高效地捕捉图像的局部和全局信息,同时保持计算效率。这使得Swin Transformer成为一个非常适合处理大规模图像任务的深度学习模型,尤其在视觉任务中具有优异的性能。

三、代码实现

1.库函数导入

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score

2.模型搭建

# 基础窗口注意力模块
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape  # (B, window_size*window_size, C)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5))
        attn = self.softmax(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return out

# Swin Transformer Block
class SwinBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=4, shift_size=0, mlp_ratio=4.0):
        super(SwinBlock, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size=window_size, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim)
        )

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

        # 分块
        x_windows = x.unfold(1, self.window_size, self.window_size).unfold(2, self.window_size, self.window_size)
        x_windows = x_windows.contiguous().view(-1, self.window_size * self.window_size, C)

        attn_windows = self.attn(x_windows)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # 合并窗口
        num_windows = H // self.window_size
        x = attn_windows.view(B, num_windows, num_windows, self.window_size, self.window_size, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        x = x.view(B, H * W, C)
        x = shortcut + x

        # FFN
        x = x + self.mlp(self.norm2(x))
        return x

# Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        super(PatchEmbedding, self).__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # [B, embed_dim, H/patch, W/patch]
        x = x.flatten(2).transpose(1, 2)  # [B, N, embed_dim]
        return x

# Swin Transformer 主体模型
class SwinTransformer(nn.Module):
    def __init__(self, image_size=32, patch_size=4, in_chans=3, num_classes=10, embed_dim=96, depths=[2, 2], num_heads=[3, 6]):
        super(SwinTransformer, self).__init__()
        self.patch_embed = PatchEmbedding(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.pos_drop = nn.Dropout(p=0.0)

        self.layers = nn.ModuleList()
        H = W = image_size // patch_size
        for i_layer in range(len(depths)):
            layer = nn.Sequential(
                *[SwinBlock(embed_dim * (2 ** i_layer), (H // (2 ** i_layer), W // (2 ** i_layer)),
                            num_heads[i_layer], window_size=4, shift_size=(0 if j % 2 == 0 else 2))
                  for j in range(depths[i_layer])]
            )
            self.layers.append(layer)
            if i_layer < len(depths) - 1:
                self.layers.append(PatchMerging(H // (2 ** i_layer), W // (2 ** i_layer), embed_dim * (2 ** i_layer)))

        self.norm = nn.LayerNorm(embed_dim * (2 ** (len(depths) - 1)))
        self.head = nn.Linear(embed_dim * (2 ** (len(depths) - 1)), num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        x = self.head(x)
        return x

# Patch Merging 模块(降维)
class PatchMerging(nn.Module):
    def __init__(self, H, W, dim):
        super(PatchMerging, self).__init__()
        self.reduction = nn.Linear(4 * dim, 2 * dim)
        self.norm = nn.LayerNorm(4 * dim)
        self.H = H
        self.W = W

    def forward(self, x):
        B, L, C = x.shape
        x = x.view(B, self.H, self.W, C)

        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], -1)
        x = x.view(B, -1, 4 * C)
        x = self.norm(x)
        x = self.reduction(x)
        return x

1. Patch Embedding(图像切块并映射到高维特征)

  • 将输入图像分成固定大小的 patch(例如 4x4),每个 patch 映射为一个 token。

  • 使用的是一个 Conv2d(kernel=patch_size, stride=patch_size) 来完成切块和降采样。

  • CIFAR-10 图像为 32x32x3,经过 patch_size=4 后变成 8x8=64 个 patch,每个 patch 被映射为 embed_dim=96 的向量。

2. Swin Transformer Block(滑窗注意力 + MLP)

  • 模拟局部注意力,不在全图上做 Attention,而是在小窗口(如 4x4)内进行注意力。

  • Shifted Window 技术实现跨窗口信息交互:

    • 第一层使用常规窗口划分;

    • 第二层将窗口移动一半(shift_size=2),使得窗口覆盖区域错开,带来跨窗口的信息流动。

  • 每个 Block = LayerNorm -> Window Attention -> Add -> LayerNorm -> MLP -> Add

3. Patch Merging(类似 CNN 的下采样)

  • 2x2 个相邻 patch 合并成一个,通道翻倍,空间尺寸减半。

  • 有点类似 CNN 中的 Pooling 或 Strided Conv。

  • 构建层次结构:高层结构感受野更大,抽象能力更强。

4. 分类头(全局平均池化 + 全连接)

  • 最后通过 mean pooling(对所有 token 取平均),得到一个全局特征;

  • 接一个 Linear 层用于分类

模块名 功能描述 输出形状(以 CIFAR-10 为例)
PatchEmbedding 将 32×32×3 图像切为 8×8 patch,嵌入 96 维 [B, 64, 96]
SwinBlock ×2 window attention(含 shift),提取局部特征 [B, 64, 96]
PatchMerging 2×2 合并为 1,通道翻倍为 192,大小变为 4×4 [B, 16, 192]
SwinBlock ×2 高层注意力提取跨窗口抽象特征 [B, 16, 192]
LayerNorm + Avg 归一化后取平均,得到一个全局 token [B, 192]
Linear Classifier 映射为分类结果 [B, 10]

3.数据加载

# 数据集加载与预处理
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

4.初始换模型

# 初始化模型
model = SwinTransformer()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 损失与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

5.训练函数

# 训练函数
def train(model, train_loader, criterion, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        total = correct = loss_sum = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f'Epoch {epoch+1}, Loss: {loss_sum/len(train_loader):.4f}, Acc: {correct/total:.4f}')

6.测试函数

def test(model, test_loader):
    model.eval()
    total = correct = 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    print(f'Test Accuracy: {acc:.4f}')

7.训练与测试


# 训练和测试
train(model, train_loader, criterion, optimizer, num_epochs=40)
test(model, test_loader)

你可能感兴趣的:(pytorch,transformer,人工智能)