如果你对人工智能(AI)或深度学习(Deep Learning)感兴趣,可能听说过“Transformer”这个词。它最初在自然语言处理(NLP)领域大放异彩,比如在翻译、聊天机器人和文本生成中表现出色。但你知道吗?Transformer不仅能处理文字,还能用来分类图像!这听起来是不是有点神奇?别担心,这篇博客将带你从零开始,了解Transformer的基本概念、它如何被应用到图像分类,以及通过一个简单的例子让你直观理解它的运作原理。无论你是AI新手还是好奇的技术爱好者,这篇文章都会尽量用通俗的语言为你解锁Transformer的奥秘。
Transformer是一种深度学习模型,最早由Vaswani等人在2017年的论文《Attention is All You Need》中提出。它的核心思想是“注意力机制”(Attention Mechanism),这是一种让模型学会“关注”输入中重要部分的能力。传统的模型,比如卷积神经网络(CNN)和循环神经网络(RNN),在处理图像或序列数据时有局限性,而Transformer通过注意力机制突破了这些限制。
“Transformer”这个名字听起来很酷,但它其实反映了模型的功能:它能将输入数据“转换”(Transform)成更有意义的表示形式。比如,把一句话翻译成另一种语言,或者把一张图片“翻译”成一个分类标签(比如“猫”或“狗”)。它的核心在于通过计算输入数据之间的关系,生成更有用的输出。
Transformer由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。不过,在图像分类任务中,我们通常只用到编码器部分。让我们简单看看它的组成:
这些组件堆叠在一起,形成多层结构,每一层都让模型对数据的理解更深一层。
注意力机制是Transformer的核心。想象你在读一本书,当你看到“猫”这个词时,你会自动想到整句话的上下文,比如“猫在睡觉”还是“猫在跑”。注意力机制让模型也能做到这一点:它会计算输入中每个部分对其他部分的“重要性”,然后根据这些关系调整输出。
具体来说,Transformer使用的是“自注意力”(Self-Attention)。它会为输入的每个部分(比如图像的一个小块)生成三个向量:
通过计算查询和键之间的相似度,模型决定每个值的权重,然后把它们加权组合起来。这种方式让Transformer能捕捉全局关系,而不是像CNN那样只关注局部区域。
Transformer最初是为NLP设计的,那它是怎么“跨界”到图像分类的呢?这要归功于2020年提出的Vision Transformer(简称ViT)。让我们看看它是如何工作的。
图像和文字完全不同,对吧?图像是一堆像素,而文字是一串单词。要让Transformer处理图像,第一步就是把图像“翻译”成它能理解的形式。ViT的做法是:
经过这些步骤,一张图像就变成了一个序列(Sequence),就像NLP中的一句话,只不过这里的“单词”是图像块。
一旦图像被转换成序列,Transformer的编码器就开始工作:
相比传统的CNN,ViT有几个优点:
但它也有挑战:
为了让新手更容易理解,我们通过一个具体的例子来说明Transformer如何进行图像分类。假设我们要训练一个模型,区分CIFAR-10数据集中的“猫”和“狗”图片(CIFAR-10是PyTorch内置的一个小型图像数据集,包含10类32x32像素的图像)。下面我们逐步拆解过程,并新增代码实现。
CIFAR-10中的每张图片是32x32像素,RGB格式。我们将它切成4x4的小块(为了简化示例),总共有64个块(32 ÷ 4 = 8,8x8 = 64)。每个小块有48个数值(4x4x3,因为RGB有3个通道)。
现在,这张图片变成了一个64x32的矩阵,就像一个有64个“单词”的序列。
假设猫咪的耳朵在第10个块,眼睛在第20个块。Transformer会:
经过多层自注意力,模型学会关联猫的特征。
在最后一层,ViT取一个特殊的“分类标记”(CLS Token),通过全连接层输出10个类别的概率(CIFAR-10有10类),比如“猫”的概率是0.8,“狗”是0.1。
下面我们提供两种代码实现方式,帮助你直观感受ViT的运作。代码基于PyTorch,使用CIFAR-10数据集。
这个实现简化了ViT的核心组件,适合理解原理。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 超参数
patch_size = 4 # 切分图像为4x4的小块
embed_dim = 32 # 每个小块的嵌入维度
num_heads = 4 # 注意力头的数量
num_classes = 10 # CIFAR-10有10个类别
num_patches = (32 // patch_size) ** 2 # 64个小块 (32x32图像)
# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
# 简化的ViT模型
class SimpleViT(nn.Module):
def __init__(self):
super(SimpleViT, self).__init__()
# 将图像块映射到嵌入空间
self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, embed_dim)
# 位置编码
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
# CLS Token
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
# Transformer编码器
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads), num_layers=2)
# 分类头
self.fc = nn.Linear(embed_dim, num_classes)
def forward(self, x):
b, c, h, w = x.shape # [batch_size, 3, 32, 32]
# 切分成小块并展平
x = x.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
x = x.permute(0, 2, 4, 1, 3, 5).contiguous() # [b, 8, 8, 3, 4, 4]
x = x.view(b, num_patches, -1) # [b, 64, 48]
# 映射到嵌入空间
x = self.patch_to_embedding(x) # [b, 64, 32]
# 添加CLS Token
cls_tokens = self.cls_token.expand(b, -1, -1) # [b, 1, 32]
x = torch.cat((cls_tokens, x), dim=1) # [b, 65, 32]
# 加上位置编码
x = x + self.pos_embedding
# 通过Transformer
x = self.transformer(x) # [b, 65, 32]
# 取CLS Token的输出进行分类
x = self.fc(x[:, 0]) # [b, 10]
return x
# 训练模型
model = SimpleViT()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(5): # 训练5个epoch
for images, labels in trainloader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
代码解释:
这个实现利用Hugging Face的预训练ViT模型,适合快速上手。
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 数据加载
transform = transforms.Compose([
transforms.Resize((224, 224)), # ViT需要224x224输入
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)
# 加载预训练ViT模型和特征提取器
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10) # 修改分类头为10类
# 训练设置
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# 训练模型
model.train()
for epoch in range(3): # 训练3个epoch
for images, labels in trainloader:
inputs = feature_extractor(images=[img.permute(1, 2, 0).numpy() for img in images], return_tensors="pt")
inputs = {k: v for k, v in inputs.items()} # 转换为模型输入格式
optimizer.zero_grad()
outputs = model(**inputs).logits # 获取分类输出
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
代码解释:
注意:运行第二种方式需要安装transformers库(pip install transformers)。
CNN像一个放大镜,逐步扫描图像的局部特征;而Transformer像一个全景相机,一次性捕捉全局关系。两者各有千秋,ViT证明了Transformer也能在图像任务中大放异彩。
好消息是,你不需要从头写Transformer!开源工具(如PyTorch和Hugging Face)提供了预训练模型。你只需要学会加载模型、准备数据和微调,就能上手。
不完全是。ViT在大数据集(如ImageNet)上表现很好,但在小数据集或需要精细局部特征的任务上,CNN可能更合适。
Transformer通过注意力机制和全局视野,为图像分类带来了新思路。Vision Transformer(ViT)展示了它如何将图像切分成块,像处理句子一样处理图片,最终实现分类。对于新手来说,理解它的关键在于:
通过上面的代码示例,你可以看到: