import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly project each to embed_dim."""
    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        assert img_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size cuts out each patch and
        # applies the same linear projection to it.
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, H, W]
        x = self.proj(x)        # [B, embed_dim, H/ps, W/ps]
        x = x.flatten(2)        # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)   # [B, num_patches, embed_dim]
        return x

class Attention(nn.Module):
    """Multi-head self-attention over the patch (and class) tokens."""
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads."
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)  # [B, N, 3*C]
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each [B, heads, N, head_dim]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    """Position-wise feed-forward network used inside each Transformer block."""
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer block: x + Attention(LN(x)), then x + MLP(LN(x))."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                              attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)
        # Stochastic depth is skipped for simplicity, so drop_path is accepted but unused.
        self.drop_path = nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class SimpleViT(nn.Module):
    """A minimal Vision Transformer for image classification."""
    def __init__(self,
                 img_size=32,
                 patch_size=4,
                 in_chans=3,
                 num_classes=10,
                 embed_dim=64,
                 depth=6,
                 num_heads=8,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.):
        super().__init__()
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        # Class token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Transformer blocks
        self.blocks = nn.Sequential(*[
            TransformerEncoderBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate
            )
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_module)

    def _init_module(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)  # [B, num_patches, embed_dim]
        # Prepend the class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [B, num_patches + 1, embed_dim]
        x = x + self.pos_embed
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        # Classify from the class token
        cls_token_final = x[:, 0]
        x = self.head(cls_token_final)
        return x

if __name__ == "__main__":
    # Simple test
    model = SimpleViT()
    imgs = torch.randn(8, 3, 32, 32)  # batch of CIFAR-like images
    logits = model(imgs)
    print(logits.shape)  # torch.Size([8, 10])
A minimal ViT (Vision Transformer), consisting of:
Patch Embedding: splits the image into fixed-size patches and linearly projects each one.
Multi-head self-attention (Attention): the core Transformer mechanism, used to model global dependencies between patches.
MLP: the feed-forward network inside each Transformer block.
Transformer Encoder Block: a residual structure of LayerNorm + Attention + MLP.
Class Token & positional embeddings: a special token used for the classification task plus position embeddings.
Classification head: projects the final class token into the class space (the sketch after this list traces the shapes end to end).
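The following sketch is not part of the original script; it assumes the SimpleViT defined above with its default hyperparameters and simply traces the tensor shapes through each of these components.

import torch

model = SimpleViT()                          # defaults: img_size=32, patch_size=4, embed_dim=64, depth=6
x = torch.randn(2, 3, 32, 32)                # [B, C, H, W]

tokens = model.patch_embed(x)                # [2, 64, 64]: (32/4)^2 = 64 patches, each projected to 64 dims
cls = model.cls_token.expand(2, -1, -1)      # [2, 1, 64]
tokens = torch.cat((cls, tokens), dim=1)     # [2, 65, 64]: class token prepended
tokens = tokens + model.pos_embed            # learned position embeddings, shape unchanged
tokens = model.blocks(tokens)                # six pre-norm encoder blocks, still [2, 65, 64]
tokens = model.norm(tokens)

logits = model.head(tokens[:, 0])            # classify from the class token -> [2, 10]
print(logits.shape)                          # torch.Size([2, 10])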
A simple test in __main__: a batch of 8 color images of size 32×32 goes through the model and produces logits of shape [8, 10].
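To connect the [8, 10] logits back to an actual classification objective, here is an equally minimal, hypothetical training step; the optimizer, learning rate, and random labels below are illustrative placeholders, not part of the original script.

import torch
import torch.nn as nn

model = SimpleViT()                                           # the model defined above
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)    # placeholder optimizer and learning rate
criterion = nn.CrossEntropyLoss()

imgs = torch.randn(8, 3, 32, 32)                              # stand-in batch; use a real DataLoader in practice
labels = torch.randint(0, 10, (8,))                           # stand-in labels for the 10 classes

logits = model(imgs)                                          # [8, 10]
loss = criterion(logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())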