Vision Transformer (ViT) 是计算机视觉领域的一个重大突破,它将自然语言处理中的 Transformer 架构应用到图像处理任务中。本文将带大家了解 ViT 的基本原理,并通过实践案例展示如何使用 ViT 进行图像分类任务。
长期以来,卷积神经网络 (CNN) 一直是计算机视觉的主流架构。2020年,Google 研究团队发表了论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》,提出了 Vision Transformer (ViT) 模型,将 NLP 领域的 Transformer 成功应用到图像识别中。
ViT 的核心思想非常直观:将图像视为一系列的"词"。具体来说:
ViT 的处理流程可以分为以下几个步骤:
这就像 CNN 需要一点一点地拼图,而 ViT 可以同时看到所有拼图块并理解它们如何组合。
在图像分类任务中,常用的评价指标包括:
下面我们将使用 ViT 模型在 CIFAR-10 数据集上进行图像分类。CIFAR-10 包含 10 个类别的 60,000 张 32×32 彩色图像,每个类别 6,000 张。
首先,我们需要安装必要的库:
# 安装必要的库
!pip install torch torchvision timm matplotlib
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import timm
from torch import nn, optim
from torch.utils.data import DataLoader
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
ViT 通常需要 224×224 像素的输入图像,而 CIFAR-10 的图像为 32×32 像素,因此需要调整大小:
# 数据预处理
transform = transforms.Compose([
transforms.Resize(224), # ViT 需要 224x224 输入
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32, shuffle=False)
# CIFAR-10 的类别
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
# 可视化一些图像
def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.figure(figsize=(10, 4))
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.axis('off')
plt.show()
# 获取随机训练图像
dataiter = iter(trainloader)
images, labels = next(dataiter)
# 显示图像
imshow(torchvision.utils.make_grid(images[:5]))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(5)))
我们使用 timm 库加载预训练的 ViT 模型:
# 创建 ViT 模型
model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=10)
model = model.to(device)
# 查看模型结构
print(f"模型参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
# 训练函数
def train_one_epoch(model, trainloader, criterion, optimizer):
model.train()
running_loss = 0.0
correct = 0
total = 0
for inputs, labels in trainloader:
inputs, labels = inputs.to(device), labels.to(device)
# 清零梯度
optimizer.zero_grad()
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
loss.backward()
optimizer.step()
# 统计
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
# 返回平均损失和准确率
return running_loss / len(trainloader), 100. * correct / total
# 评估函数
def evaluate(model, testloader):
model.eval()
correct = 0
total = 0
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
# 整体准确率
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
# 每个类别的准确率
c = predicted.eq(labels).squeeze()
for i in range(labels.size(0)):
label = labels[i].item()
class_correct[label] += c[i].item()
class_total[label] += 1
# 打印每个类别的准确率
for i in range(10):
print(f'准确率 of {classes[i]}: {100 * class_correct[i] / class_total[i]:.1f}%')
return 100. * correct / total
由于完整训练需要较长时间,这里我们只训练几个 epoch 作为演示:
# 训练模型
num_epochs = 3
train_acc_history = []
test_acc_history = []
for epoch in range(num_epochs):
# 训练一个 epoch
train_loss, train_acc = train_one_epoch(model, trainloader, criterion, optimizer)
train_acc_history.append(train_acc)
# 评估
test_acc = evaluate(model, testloader)
test_acc_history.append(test_acc)
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, '
f'Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')
# 可视化训练过程
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs+1), train_acc_history, 'b-', label='Train Accuracy')
plt.plot(range(1, num_epochs+1), test_acc_history, 'r-', label='Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.title('Training and Testing Accuracy')
plt.legend()
plt.grid(True)
plt.show()
# 可视化一些预测结果
def visualize_predictions(model, testloader, classes, num_images=5):
model.eval()
dataiter = iter(testloader)
images, labels = next(dataiter)
images = images[:num_images].to(device)
labels = labels[:num_images]
with torch.no_grad():
outputs = model(images)
_, predicted = outputs.max(1)
# 显示图像和预测结果
images = images.cpu()
predicted = predicted.cpu()
fig = plt.figure(figsize=(12, 3))
for i in range(num_images):
ax = fig.add_subplot(1, num_images, i + 1)
imshow(images[i] / 2 + 0.5)
color = 'green' if predicted[i] == labels[i] else 'red'
ax.set_title(f'预测: {classes[predicted[i]]}\n实际: {classes[labels[i]]}',
color=color)
plt.tight_layout()
plt.show()
# 可视化一些预测结果
visualize_predictions(model, testloader, classes)
# 保存模型
torch.save(model.state_dict(), 'vit_cifar10.pth')
print("模型已保存!")
自 ViT 提出以来,研究者们开发了多种改进版本:
在实践中,以下技巧可以提高 ViT 的训练效果:
# 数据增强
transform_train = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 学习率调度
from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
# 使用混合精度训练
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
# 训练循环中使用
for inputs, labels in trainloader:
inputs, labels = inputs.to(device), labels.to(device)
# 使用混合精度
with autocast():
outputs = model(inputs)
loss = criterion(outputs, labels)
# 缩放梯度并更新参数
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
# 更新学习率
scheduler.step()
这些技巧可以显著提高 ViT 的训练效率和模型性能:
除了使用 timm 库,我们也可以使用 Hugging Face 的 Transformers 库实现 ViT。下面是一个简化的实现:
# 简单可靠的ViT实现 - CIFAR-10分类
import torch
import torchvision
import torchvision.transforms as transforms
import timm
from torch import nn, optim
from torch.utils.data import DataLoader
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 数据预处理
transform = transforms.Compose([
transforms.Resize(224), # ViT需要224x224输入
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载CIFAR-10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32, shuffle=False)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
# 创建ViT模型
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
model = model.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# 训练模型
num_epochs = 3
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
correct = 0
total = 0
for i, (inputs, labels) in enumerate(trainloader):
inputs, labels = inputs.to(device), labels.to(device)
# 清零梯度
optimizer.zero_grad()
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
loss.backward()
optimizer.step()
# 统计
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
if i % 50 == 49:
print(f'[{epoch+1}, {i+1}] loss: {running_loss/50:.3f} | '
f'acc: {100.*correct/total:.2f}%')
running_loss = 0.0
# 评估
model.eval()
test_correct = 0
test_total = 0
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
test_total += labels.size(0)
test_correct += predicted.eq(labels).sum().item()
print(f'Epoch {epoch+1} Test Accuracy: {100.*test_correct/test_total:.2f}%')
print('训练完成!')
# 保存模型
torch.save(model.state_dict(), 'vit_cifar10.pth')
这种方法使用 Hugging Face 的 Trainer API,使得训练过程更加简洁。
ViT 及其变体已在多个计算机视觉任务中取得了出色的表现:
如果遇到问题,欢迎在评论区留言或关注微信公众号【算法最TOP】获取帮助!