在上篇文章基于pytorch的Vision Transformer用于cifar10分类中我们了解VIT在视觉分类中的应用,在本篇文章中将介绍另一种transformer模型-Swin Transformer。
一、前言:
对比ViT与Swin Transformer
1.模型架构的不同
ViT:Vision Transformer(ViT)是最早将Transformer模型应用于图像分类的网络。它的核心思想是将图像分割成若干固定大小的小块(patches),然后将这些patches通过线性嵌入变成一维向量,并喂给Transformer进行处理。ViT的优点在于它能够充分捕捉全局上下文信息,因为Transformer的自注意力机制能够处理全局依赖,但这种架构也要求大量数据来避免过拟合,因为它缺乏局部信息的建模。
Swin Transformer:Swin Transformer是为了解决ViT在计算和效率上的问题提出的。它采用了滑动窗口的自注意力机制(Shifted Window),通过局部窗口内的自注意力计算来减少计算复杂度,并通过将窗口在不同层级上移动(shift),来更好地捕捉长距离的依赖关系。Swin Transformer还采用了层次化的架构,将图像分割成不同尺寸的patches,从而能够捕捉不同尺度的特征。
2.计算复杂度不同
ViT:由于ViT是全局自注意力机制,它的计算复杂度是与图像的尺寸平方成正比的,这对于高分辨率图像来说会非常昂贵。具体来说,ViT的自注意力机制计算复杂度为O(N2)O(N^2)O(N2),其中NNN是输入的tokens数量,这对于大规模图像分类来说会导致较高的计算开销。
Swin Transformer:Swin Transformer通过局部窗口计算自注意力,将复杂度降低为O(N⋅logN)O(N \cdot \log N)O(N⋅logN),因此在处理大规模图像时更为高效。它将图像分割为较小的窗口,并通过层次化处理和窗口的滑动来减少计算量,并且保持较好的性能。
3.局部与全局信息的处理
ViT:ViT本身能够处理全局信息,适用于捕捉图像的全局上下文,但由于它没有直接建模局部区域的信息(即patches之间的关系),这可能导致在处理一些较为细致的局部特征时性能较差。
Swin Transformer:Swin Transformer则通过窗口化处理和滑动窗口机制,在不同层次捕捉局部信息,并通过窗口的移动和层次化结构能够有效地结合局部和全局信息。这使得Swin Transformer在不同尺度上有更好的特征学习能力。
ViT:ViT是一种相对简单的模型结构,较为直观,但它的表现和训练数据量紧密相关。ViT在大数据集上表现优秀,但在数据较少的情况下,容易出现过拟合。
Swin Transformer:Swin Transformer通过层次化和局部窗口的设计,使得模型在训练数据量较小的情况下也能够保持较好的性能。同时,Swin Transformer具有较好的扩展性,能够有效处理不同分辨率和大小的输入图像。
ViT:在大数据集(如ImageNet-21k)上训练时,ViT表现出了非常强的能力,但对于小数据集,ViT可能无法充分发挥其优势。
Swin Transformer:Swin Transformer能够在各种视觉任务中展现出强劲的性能,包括目标检测、图像分割等。其高效的计算和较强的特征提取能力使得它在多种任务中表现出色。
计算效率更高:由于采用了局部窗口和滑动窗口机制,Swin Transformer在计算复杂度上相较于ViT大大减少,能够在处理大规模图像时更高效。
更好的局部特征建模:Swin Transformer在每个层次上通过局部窗口捕捉局部特征,并通过滑动窗口机制使得模型能够在多个尺度上学习信息,从而能够更好地建模图像中的局部细节和全局依赖。
灵活性强:Swin Transformer通过层次化架构和不同大小的窗口设计,使得它能够适应不同大小、不同分辨率的图像输入,具有较强的扩展性。
更好的性能:Swin Transformer在多项视觉任务中表现出色,包括图像分类、目标检测、图像分割等任务,尤其在计算资源有限的情况下表现得尤为突出。
二、Swin Transformer的工作原理的详细分析:
Swin Transformer采用层次化结构(Hierarchical Architecture),通过分层的方式逐步缩小特征图的尺寸,并逐层提取更高层次的抽象特征。与ViT将整个图像直接分割成固定大小的patches不同,Swin Transformer在每一层中通过改变patch的尺寸来构建多尺度的特征表示。
低层次:输入图像被划分为较小的patches,每个patch都包含较小的局部信息。
高层次:随着网络深度的增加,patch的尺寸逐步增大,从而捕获更大范围的特征。
传统的Transformer使用全局自注意力(Global Attention),但计算复杂度是O(N2)O(N^2)O(N2),其中N是tokens的数量,对于高分辨率图像来说计算量巨大。Swin Transformer通过引入局部窗口和滑动窗口来降低计算复杂度。
局部窗口:Swin Transformer将输入图像划分为固定大小的局部窗口(例如,7x7或8x8的patch)。在每个窗口内,采用自注意力机制计算该窗口内的tokens之间的相互关系。
滑动窗口:为了克服窗口之间信息隔离的问题,Swin Transformer采用了滑动窗口策略(Shifted Window)。即在每一层,窗口的划分位置会略微偏移(shift),从而在不同层次上能够捕捉到窗口之间的交互和全局依赖。
这种滑动窗口机制通过窗口的平移,有效地减少了计算复杂度,并允许信息跨窗口进行交流,避免了信息孤岛现象。
Swin Transformer通过分层和滑动窗口的设计,使得它能够同时处理局部信息和全局信息:
在较低层次,局部窗口能够专注于细节特征,捕捉局部模式(如边缘、纹理等)。
在较高层次,通过滑动窗口和平移机制,跨窗口的信息能够融合,形成全局上下文感知能力,使模型能够捕捉图像中的全局依赖。
为了使得模型能够处理更大范围的上下文信息,Swin Transformer采用了Patch Merging操作来进行下采样。通过合并多个小patch,生成更大的patch。这一操作不仅减少了特征图的大小,还提高了表示能力。在网络的每一层,Swin Transformer通过这种方式减少计算量,同时保留了足够的特征信息。
Patch Merging:在每一层的输出中,多个相邻的patch会合并成一个较大的patch,并通过一个线性映射来处理。
Swin Transformer在每一层的操作与标准Transformer类似,包括:
多头自注意力(Multi-Head Attention):对于每个局部窗口,Swin Transformer通过多头自注意力来捕捉每个patch之间的相互依赖关系。自注意力机制计算每个位置(token)对其他位置(token)的关注权重,从而能够聚合全局信息。
前馈神经网络(Feedforward Network):每个自注意力层后接一个前馈神经网络,用于进一步转换特征。
与传统的Transformer类似,Swin Transformer也采用了层归一化(Layer Normalization)和残差连接(Residual Connection)来避免梯度消失问题,同时加速收敛。
残差连接:每个Transformer模块内部都包含了残差连接,这使得信息能够直接跨越多个层传递,缓解了深层网络训练中的难题。
通过多层堆叠的Swin Transformer模块,最终会得到一个层次化的特征表示。输出可以用于不同的任务,如图像分类、目标检测或图像分割等。对于图像分类任务,最后会将特征传递到一个分类头进行预测。
Swin Transformer的工作原理结合了Transformer的全局建模能力和局部感知能力。通过分层结构、局部窗口自注意力、滑动窗口机制和Patch Merging,它能够高效地捕捉图像的局部和全局信息,同时保持计算效率。这使得Swin Transformer成为一个非常适合处理大规模图像任务的深度学习模型,尤其在视觉任务中具有优异的性能。
三、代码实现
1.库函数导入
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score
2.模型搭建
# 基础窗口注意力模块
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads):
super(WindowAttention, self).__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.softmax = nn.Softmax(dim=-1)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
B, N, C = x.shape # (B, window_size*window_size, C)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5))
attn = self.softmax(attn)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
out = self.proj(out)
return out
# Swin Transformer Block
class SwinBlock(nn.Module):
def __init__(self, dim, input_resolution, num_heads, window_size=4, shift_size=0, mlp_ratio=4.0):
super(SwinBlock, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.window_size = window_size
self.shift_size = shift_size
self.num_heads = num_heads
self.norm1 = nn.LayerNorm(dim)
self.attn = WindowAttention(dim, window_size=window_size, num_heads=num_heads)
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim)
)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "Input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
if self.shift_size > 0:
x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# 分块
x_windows = x.unfold(1, self.window_size, self.window_size).unfold(2, self.window_size, self.window_size)
x_windows = x_windows.contiguous().view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
# 合并窗口
num_windows = H // self.window_size
x = attn_windows.view(B, num_windows, num_windows, self.window_size, self.window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
if self.shift_size > 0:
x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
x = x.view(B, H * W, C)
x = shortcut + x
# FFN
x = x + self.mlp(self.norm2(x))
return x
# Patch Embedding
class PatchEmbedding(nn.Module):
def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
super(PatchEmbedding, self).__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x) # [B, embed_dim, H/patch, W/patch]
x = x.flatten(2).transpose(1, 2) # [B, N, embed_dim]
return x
# Swin Transformer 主体模型
class SwinTransformer(nn.Module):
def __init__(self, image_size=32, patch_size=4, in_chans=3, num_classes=10, embed_dim=96, depths=[2, 2], num_heads=[3, 6]):
super(SwinTransformer, self).__init__()
self.patch_embed = PatchEmbedding(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
self.pos_drop = nn.Dropout(p=0.0)
self.layers = nn.ModuleList()
H = W = image_size // patch_size
for i_layer in range(len(depths)):
layer = nn.Sequential(
*[SwinBlock(embed_dim * (2 ** i_layer), (H // (2 ** i_layer), W // (2 ** i_layer)),
num_heads[i_layer], window_size=4, shift_size=(0 if j % 2 == 0 else 2))
for j in range(depths[i_layer])]
)
self.layers.append(layer)
if i_layer < len(depths) - 1:
self.layers.append(PatchMerging(H // (2 ** i_layer), W // (2 ** i_layer), embed_dim * (2 ** i_layer)))
self.norm = nn.LayerNorm(embed_dim * (2 ** (len(depths) - 1)))
self.head = nn.Linear(embed_dim * (2 ** (len(depths) - 1)), num_classes)
def forward(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
x = x.mean(dim=1)
x = self.head(x)
return x
# Patch Merging 模块(降维)
class PatchMerging(nn.Module):
def __init__(self, H, W, dim):
super(PatchMerging, self).__init__()
self.reduction = nn.Linear(4 * dim, 2 * dim)
self.norm = nn.LayerNorm(4 * dim)
self.H = H
self.W = W
def forward(self, x):
B, L, C = x.shape
x = x.view(B, self.H, self.W, C)
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
x = torch.cat([x0, x1, x2, x3], -1)
x = x.view(B, -1, 4 * C)
x = self.norm(x)
x = self.reduction(x)
return x
将输入图像分成固定大小的 patch(例如 4x4
),每个 patch 映射为一个 token。
使用的是一个 Conv2d(kernel=patch_size, stride=patch_size)
来完成切块和降采样。
CIFAR-10 图像为 32x32x3
,经过 patch_size=4
后变成 8x8=64
个 patch,每个 patch 被映射为 embed_dim=96
的向量。
模拟局部注意力,不在全图上做 Attention,而是在小窗口(如 4x4
)内进行注意力。
Shifted Window 技术实现跨窗口信息交互:
第一层使用常规窗口划分;
第二层将窗口移动一半(shift_size=2),使得窗口覆盖区域错开,带来跨窗口的信息流动。
每个 Block = LayerNorm -> Window Attention -> Add -> LayerNorm -> MLP -> Add
将 2x2
个相邻 patch 合并成一个,通道翻倍,空间尺寸减半。
有点类似 CNN 中的 Pooling 或 Strided Conv。
构建层次结构:高层结构感受野更大,抽象能力更强。
最后通过 mean pooling
(对所有 token 取平均),得到一个全局特征;
接一个 Linear
层用于分类
模块名 | 功能描述 | 输出形状(以 CIFAR-10 为例) |
---|---|---|
PatchEmbedding | 将 32×32×3 图像切为 8×8 patch,嵌入 96 维 | [B, 64, 96] |
SwinBlock ×2 | window attention(含 shift),提取局部特征 | [B, 64, 96] |
PatchMerging | 2×2 合并为 1,通道翻倍为 192,大小变为 4×4 | [B, 16, 192] |
SwinBlock ×2 | 高层注意力提取跨窗口抽象特征 | [B, 16, 192] |
LayerNorm + Avg | 归一化后取平均,得到一个全局 token | [B, 192] |
Linear Classifier | 映射为分类结果 | [B, 10] |
3.数据加载
# 数据集加载与预处理
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
4.初始换模型
# 初始化模型
model = SwinTransformer()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 损失与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
5.训练函数
# 训练函数
def train(model, train_loader, criterion, optimizer, num_epochs=10):
for epoch in range(num_epochs):
model.train()
total = correct = loss_sum = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum += loss.item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Epoch {epoch+1}, Loss: {loss_sum/len(train_loader):.4f}, Acc: {correct/total:.4f}')
6.测试函数
def test(model, test_loader):
model.eval()
total = correct = 0
all_preds, all_labels = [], []
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
all_preds.extend(predicted.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
acc = accuracy_score(all_labels, all_preds)
print(f'Test Accuracy: {acc:.4f}')
7.训练与测试
# 训练和测试
train(model, train_loader, criterion, optimizer, num_epochs=40)
test(model, test_loader)