26备战秋招day19——VIT

Vision Transformer (ViT):变革计算机视觉的新架构

Vision Transformer (ViT) 是计算机视觉领域的一个重大突破,它将自然语言处理中的 Transformer 架构应用到图像处理任务中。本文将带大家了解 ViT 的基本原理,并通过实践案例展示如何使用 ViT 进行图像分类任务。


一、ViT 模型介绍

1. 背景

长期以来,卷积神经网络 (CNN) 一直是计算机视觉的主流架构。2020年,Google 研究团队发表了论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》,提出了 Vision Transformer (ViT) 模型,将 NLP 领域的 Transformer 成功应用到图像识别中。

2. ViT 的核心思想

ViT 的核心思想非常直观:将图像视为一系列的"词"。具体来说:

  • 将图像分割成固定大小的小方块(例如 16×16 像素)
  • 将这些图像块视为序列中的"词元"(tokens)
  • 使用标准 Transformer 编码器处理这个序列
  • 添加一个特殊的分类标记 [CLS] 用于最终分类
3. ViT 的工作流程

ViT 的处理流程可以分为以下几个步骤:

  1. 图像分块:将输入图像(如 224×224 像素)分割成不重叠的小块(如 16×16 像素)
  2. 线性投影:将每个图像块通过线性层转换为固定维度的向量
  3. 位置编码:添加位置编码,因为 Transformer 本身不知道输入序列的顺序
  4. Transformer 编码器:使用多层自注意力机制和前馈网络处理序列
  5. 分类头:使用 [CLS] 标记的最终表示进行分类
4. ViT 与 CNN 的区别
  • CNN:通过局部感受野逐层提取特征,具有局部连接和权重共享特性
  • ViT:直接通过自注意力机制同时处理所有图像块,能够捕捉全局关系

这就像 CNN 需要一点一点地拼图,而 ViT 可以同时看到所有拼图块并理解它们如何组合。


二、评价指标介绍

在图像分类任务中,常用的评价指标包括:

  1. 准确率 (Accuracy):正确分类的样本占总样本的比例
  2. 精度 (Precision):预测为某类的样本中,实际属于该类的比例
  3. 召回率 (Recall):实际属于某类的样本中,被正确预测的比例
  4. F1 分数:精度和召回率的调和平均
  5. 混淆矩阵:直观展示各类别的预测情况

三、实践案例:基于 ViT 的图像分类

下面我们将使用 ViT 模型在 CIFAR-10 数据集上进行图像分类。CIFAR-10 包含 10 个类别的 60,000 张 32×32 彩色图像,每个类别 6,000 张。

1. 安装依赖

首先,我们需要安装必要的库:

# 安装必要的库
!pip install torch torchvision timm matplotlib
2. 导入模块和设置设备
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import timm
from torch import nn, optim
from torch.utils.data import DataLoader

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
3. 数据加载和预处理

ViT 通常需要 224×224 像素的输入图像,而 CIFAR-10 的图像为 32×32 像素,因此需要调整大小:

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(224),  # ViT 需要 224x224 输入
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32, shuffle=False)

# CIFAR-10 的类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')
4. 可视化一些样本图像
# 可视化一些图像
def imshow(img):
    img = img / 2 + 0.5  # 反归一化
    npimg = img.numpy()
    plt.figure(figsize=(10, 4))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()

# 获取随机训练图像
dataiter = iter(trainloader)
images, labels = next(dataiter)

# 显示图像
imshow(torchvision.utils.make_grid(images[:5]))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(5)))
5. 创建 ViT 模型

我们使用 timm 库加载预训练的 ViT 模型:

# 创建 ViT 模型
model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=10)
model = model.to(device)

# 查看模型结构
print(f"模型参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
6. 定义训练和评估函数
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)

# 训练函数
def train_one_epoch(model, trainloader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # 清零梯度
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        loss.backward()
        optimizer.step()
        
        # 统计
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    # 返回平均损失和准确率
    return running_loss / len(trainloader), 100. * correct / total

# 评估函数
def evaluate(model, testloader):
    model.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            
            # 整体准确率
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # 每个类别的准确率
            c = predicted.eq(labels).squeeze()
            for i in range(labels.size(0)):
                label = labels[i].item()
                class_correct[label] += c[i].item()
                class_total[label] += 1
    
    # 打印每个类别的准确率
    for i in range(10):
        print(f'准确率 of {classes[i]}: {100 * class_correct[i] / class_total[i]:.1f}%')
    
    return 100. * correct / total
7. 训练模型

由于完整训练需要较长时间,这里我们只训练几个 epoch 作为演示:

# 训练模型
num_epochs = 3
train_acc_history = []
test_acc_history = []

for epoch in range(num_epochs):
    # 训练一个 epoch
    train_loss, train_acc = train_one_epoch(model, trainloader, criterion, optimizer)
    train_acc_history.append(train_acc)
    
    # 评估
    test_acc = evaluate(model, testloader)
    test_acc_history.append(test_acc)
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, '
          f'Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')
8. 可视化训练过程
# 可视化训练过程
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs+1), train_acc_history, 'b-', label='Train Accuracy')
plt.plot(range(1, num_epochs+1), test_acc_history, 'r-', label='Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.title('Training and Testing Accuracy')
plt.legend()
plt.grid(True)
plt.show()
9. 可视化预测结果
# 可视化一些预测结果
def visualize_predictions(model, testloader, classes, num_images=5):
    model.eval()
    dataiter = iter(testloader)
    images, labels = next(dataiter)
    images = images[:num_images].to(device)
    labels = labels[:num_images]
    
    with torch.no_grad():
        outputs = model(images)
        _, predicted = outputs.max(1)
    
    # 显示图像和预测结果
    images = images.cpu()
    predicted = predicted.cpu()
    
    fig = plt.figure(figsize=(12, 3))
    for i in range(num_images):
        ax = fig.add_subplot(1, num_images, i + 1)
        imshow(images[i] / 2 + 0.5)
        color = 'green' if predicted[i] == labels[i] else 'red'
        ax.set_title(f'预测: {classes[predicted[i]]}\n实际: {classes[labels[i]]}', 
                    color=color)
    plt.tight_layout()
    plt.show()

# 可视化一些预测结果
visualize_predictions(model, testloader, classes)
10. 保存模型
# 保存模型
torch.save(model.state_dict(), 'vit_cifar10.pth')
print("模型已保存!")

四、ViT 的优势与挑战

优势:
  1. 全局视野:自注意力机制允许模型同时关注整个图像的所有部分
  2. 可扩展性:在大规模数据集上表现优异
  3. 统一架构:与 NLP 使用相同架构,便于多模态融合
  4. 灵活性:不受卷积操作的空间局限性限制
挑战:
  1. 数据需求:需要比 CNN 更多的训练数据
  2. 计算复杂度:自注意力计算复杂度为 O(n²),其中 n 为图像块数量
  3. 缺乏归纳偏置:不像 CNN 那样内置了图像的局部性和平移不变性

五、ViT 的变体与发展

自 ViT 提出以来,研究者们开发了多种改进版本:

  1. DeiT:通过知识蒸馏减少对大规模数据的依赖
  2. Swin Transformer:使用移动窗口注意力,降低计算复杂度
  3. CvT:结合卷积和 Transformer 的优势
  4. MAE:通过掩码自编码预训练,提高数据效率

六、实用训练技巧

在实践中,以下技巧可以提高 ViT 的训练效果:

# 数据增强
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 学习率调度
from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

# 使用混合精度训练
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()

# 训练循环中使用
for inputs, labels in trainloader:
    inputs, labels = inputs.to(device), labels.to(device)
    
    # 使用混合精度
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    
    # 缩放梯度并更新参数
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    
    # 更新学习率
    scheduler.step()

这些技巧可以显著提高 ViT 的训练效率和模型性能:

  1. 数据增强:随机裁剪、翻转和颜色抖动等可以增加训练数据的多样性,提高模型泛化能力
  2. 学习率调度:余弦退火学习率可以帮助模型更好地收敛
  3. 混合精度训练:使用 FP16 可以加速训练并减少内存占用
  4. 权重衰减:适当的权重衰减可以防止过拟合

七、使用 Hugging Face 实现 ViT

除了使用 timm 库,我们也可以使用 Hugging Face 的 Transformers 库实现 ViT。下面是一个简化的实现:

# 简单可靠的ViT实现 - CIFAR-10分类
import torch
import torchvision
import torchvision.transforms as transforms
import timm
from torch import nn, optim
from torch.utils.data import DataLoader

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(224),  # ViT需要224x224输入
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载CIFAR-10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                     download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

# 创建ViT模型
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# 训练模型
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        # 清零梯度
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        loss.backward()
        optimizer.step()
        
        # 统计
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        if i % 50 == 49:
            print(f'[{epoch+1}, {i+1}] loss: {running_loss/50:.3f} | '
                  f'acc: {100.*correct/total:.2f}%')
            running_loss = 0.0
    
    # 评估
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
    
    print(f'Epoch {epoch+1} Test Accuracy: {100.*test_correct/test_total:.2f}%')

print('训练完成!')

# 保存模型
torch.save(model.state_dict(), 'vit_cifar10.pth')

这种方法使用 Hugging Face 的 Trainer API,使得训练过程更加简洁。


八、ViT 的应用领域

ViT 及其变体已在多个计算机视觉任务中取得了出色的表现:

  1. 图像分类:在 ImageNet 等数据集上超越传统 CNN
  2. 目标检测:DETR 等结合 Transformer 的检测器表现优异
  3. 图像分割:基于 ViT 的分割模型如 SETR 提供了更好的全局上下文
  4. 视频理解:时空 Transformer 可以有效处理视频序列
  5. 医学图像分析:在医学影像分割和病变检测中表现出色
  6. 多模态学习:CLIP、DALL-E 等结合视觉与语言的模型

如果遇到问题,欢迎在评论区留言或关注微信公众号【算法最TOP】获取帮助!

你可能感兴趣的:(26秋招,大模型,人工智能,语言模型)