关键词:少样本学习、PyTorch、元学习、支持集、原型网络
摘要:传统深度学习需要"海量数据喂养",但现实中很多场景(如罕见病诊断、新物种识别)只有少量样本。本文将用"小学生考试"的比喻,带您一步步理解少样本学习(Few-shot Learning)的核心原理,并用PyTorch实现一个能"看5张图学会新类别"的原型网络(Prototypical Networks),最后结合医疗、金融等真实场景,揭示这项技术如何让AI从"数据吃货"变身"学习高手"。
想象一下:医生拿到一种从未见过的罕见病患者的5张CT图像,希望AI能快速识别同类病例;程序员需要让客服机器人识别"用户抱怨快递盒破损"这个新意图,但只有3条标注数据。这些场景中,传统深度学习(需要成千上万张标注图)完全失效,少样本学习(Few-shot Learning, FSL)正是为解决这类问题而生。本文将覆盖少样本学习的核心概念、经典算法(原型网络)、PyTorch实战代码,以及真实应用场景。
本文将按照"概念理解→原理拆解→代码实战→场景落地"的逻辑展开:先用"考试小抄"的故事解释少样本学习;再拆解核心概念(支持集、元学习等)并画流程图;接着用PyTorch实现原型网络,逐行解读代码;最后结合医疗/金融场景说明应用价值,推荐学习资源。
术语 | 解释(用小学生能懂的话) |
---|---|
支持集(Support Set) | 像考试前老师给的"小抄":AI学习新任务时,用来"参考"的少量样本(比如5张猫的照片) |
查询集(Query Set) | 像考试题目:AI学完"小抄"后需要识别的新样本(比如1张没见过的猫的照片) |
元学习(Meta-Learning) | "学习如何学习"的超能力:AI不是死记硬背具体知识,而是学会"如何快速掌握新知识"的方法(就像学生学会高效记笔记的方法,新科目也能快速上手) |
原型(Prototype) | 同类样本的"平均特征":比如把5张猫的照片的特征"揉成一团",得到一个代表猫的"典型特征",新照片和它越像就是猫 |
小明是个转学生,每天要转去新学校上不同的课(比如周一学火星语,周二学恐龙分类)。但每门新课他只能看到5页课本(支持集),然后就要考试(用查询集测试)。普通学生只能死记硬背这5页,遇到没见过的题目就抓瞎。但小明学会了"学习方法"(元学习):他发现所有课本的知识都能提炼成"关键词特征"(比如火星语的"尖刺发音"、恐龙的"尾巴长度"),然后把5页课本的关键词揉成"典型特征"(原型),考试时只要新题目和哪个"典型特征"像,就选哪个答案。少样本学习中的AI,就像小明这样的"学习高手"。
元学习(Meta-Learning)的核心是"学习如何学习"。传统深度学习像"填鸭式学生",只能记住具体知识(比如只认识常见的猫);元学习像"学习方法大师",它从大量"小任务"(比如学过的1000种动物分类任务)中总结出"如何用少量样本快速学习新任务"的方法。就像学生通过做大量试卷,总结出"先划重点再记忆"的学习方法,遇到新科目也能快速上手。
原型网络(Prototypical Networks)是少样本学习的经典算法。它的思路很简单:给每个类别计算一个"典型特征"(原型),新样本和哪个原型最像,就属于哪个类别。比如支持集中有5张猫的照片,先把每张照片用神经网络转换成特征向量(像给照片做"特征指纹"),然后把这5个指纹取平均,得到猫的"原型指纹"。新照片的指纹和猫的原型指纹距离近,就判断为猫。
少样本学习的核心流程:
graph TD
A[元训练阶段] --> B[历史任务1: 支持集S1+查询集Q1]
A --> C[历史任务2: 支持集S2+查询集Q2]
A --> D[...多个历史任务...]
B --> E[提取S1中每个样本的特征]
C --> E
D --> E
E --> F[计算每个类别的原型(特征平均)]
F --> G[用原型对Q1/Q2/...分类,计算损失并更新模型]
G --> H[模型学会"如何生成原型并匹配查询集"]
H --> I[元测试阶段]
I --> J[新任务: 支持集S_new+查询集Q_new]
J --> K[提取S_new特征,生成新原型]
K --> L[用新原型对Q_new分类]
原型网络的核心是"特征提取→计算原型→距离匹配"三步骤,我们用PyTorch实现这一过程。
需要一个神经网络(称为"嵌入网络",Encoder),把图像(或其他数据)转换成固定长度的特征向量。就像把中文翻译成英文,嵌入网络把"像素矩阵"翻译成"特征语言",让计算机能理解图像的"内在含义"。
假设支持集有N个类别(N-way),每个类别有K个样本(K-shot),总共有N×K个样本。对每个类别,取该类别K个样本的特征向量的平均值,得到该类的原型(Prototype)。数学上表示为:
p c = 1 K ∑ i = 1 K f ( x c , i ) p_c = \frac{1}{K} \sum_{i=1}^K f(x_{c,i}) pc=K1i=1∑Kf(xc,i)
其中, f ( x ) f(x) f(x)是嵌入网络的输出(特征向量), p c p_c pc是类别c的原型。
对于查询集中的每个样本 x q x_q xq,计算它的特征 f ( x q ) f(x_q) f(xq)与所有类别原型 p c p_c pc的距离(常用欧氏距离或余弦相似度)。距离最近的类别即为预测结果。欧氏距离公式:
d ( f ( x q ) , p c ) = ∑ i = 1 D ( f ( x q ) i − p c i ) 2 d(f(x_q), p_c) = \sqrt{\sum_{i=1}^D (f(x_q)_i - p_c^i)^2} d(f(xq),pc)=i=1∑D(f(xq)i−pci)2
其中D是特征向量的维度(比如64维)。
import torch
import torch.nn as nn
import torch.nn.functional as F
# 步骤1:定义嵌入网络(简单CNN为例)
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 64, 3), # 输入1通道(灰度图),输出64通道
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2), # 下采样,缩小图像尺寸
# 重复3次卷积块,最终输出64维特征
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.fc = nn.Linear(64, 64) # 全连接层输出64维特征
def forward(self, x):
x = self.conv(x) # 卷积后得到 (batch, 64, 1, 1)
x = x.view(x.size(0), -1) # 展平为 (batch, 64)
x = self.fc(x) # 输出64维特征向量
return x
# 步骤2:计算原型(在训练/测试时动态计算)
def compute_prototypes(support_features, support_labels, num_classes):
# support_features: (N*K, 64),N是类别数,K是样本数
# support_labels: (N*K,),类别标签(0到N-1)
prototypes = []
for c in range(num_classes):
# 选出属于类别c的所有样本的特征
mask = (support_labels == c)
class_features = support_features[mask]
# 计算平均特征(原型)
prototype = class_features.mean(dim=0)
prototypes.append(prototype)
return torch.stack(prototypes) # (N, 64)
# 步骤3:计算查询样本与原型的距离(欧氏距离)
def euclidean_distance(query_features, prototypes):
# query_features: (Q, 64),Q是查询样本数
# prototypes: (N, 64)
# 计算每个查询样本与所有原型的距离
# 公式:(a - b)^2 = a² + b² - 2ab
a_sq = query_features.pow(2).sum(dim=1, keepdim=True) # (Q, 1)
b_sq = prototypes.pow(2).sum(dim=1) # (N,)
ab = query_features @ prototypes.T # (Q, N)
distances = a_sq + b_sq - 2 * ab # (Q, N)
return distances
原型网络的训练目标是:让查询样本的预测类别尽可能接近真实标签。预测概率通过"负距离"的softmax计算(距离越近,概率越高):
p ( c ∣ x q ) = exp ( − d ( f ( x q ) , p c ) ) ∑ c ′ = 1 N exp ( − d ( f ( x q ) , p c ′ ) ) p(c|x_q) = \frac{\exp(-d(f(x_q), p_c))}{\sum_{c'=1}^N \exp(-d(f(x_q), p_{c'}))} p(c∣xq)=∑c′=1Nexp(−d(f(xq),pc′))exp(−d(f(xq),pc))
损失函数使用交叉熵:
L = − 1 Q ∑ q = 1 Q log ( p ( y q ∣ x q ) ) \mathcal{L} = -\frac{1}{Q} \sum_{q=1}^Q \log(p(y_q|x_q)) L=−Q1q=1∑Qlog(p(yq∣xq))
其中 y q y_q yq是查询样本 x q x_q xq的真实标签, Q Q Q是查询集样本数。
pip install torch torchvision
)我们将实现一个完整的原型网络训练流程,包含数据加载、模型定义、训练循环。
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class OmniglotDataset(Dataset):
def __init__(self, root, split='train', n_way=5, k_shot=1, q_query=15):
self.root = os.path.join(root, split)
self.n_way = n_way # 每个任务的类别数(N-way)
self.k_shot = k_shot # 每个类别的支持样本数(K-shot)
self.q_query = q_query # 每个类别的查询样本数
self.transform = transforms.Compose([
transforms.Resize((28, 28)), # 缩放到28x28
transforms.ToTensor(), # 转成Tensor(0-1)
transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到-1~1
])
# 收集所有字符路径(每个字符是一个类别)
self.classes = [os.path.join(self.root, family, char)
for family in os.listdir(self.root)
for char in os.listdir(os.path.join(self.root, family))]
self.classes = sorted(self.classes)
def __getitem__(self, index):
# 随机选择n_way个类别
selected_classes = np.random.choice(self.classes, self.n_way, replace=False)
support_images = []
support_labels = []
query_images = []
query_labels = []
for label, cls in enumerate(selected_classes):
# 每个类别选k_shot + q_query张图片
all_images = [os.path.join(cls, img) for img in os.listdir(cls)]
selected_images = np.random.choice(all_images, self.k_shot + self.q_query, replace=False)
# 前k_shot张作为支持集
for img_path in selected_images[:self.k_shot]:
img = Image.open(img_path).convert('L') # 转为灰度图
support_images.append(self.transform(img))
support_labels.append(label)
# 后q_query张作为查询集
for img_path in selected_images[self.k_shot:]:
img = Image.open(img_path).convert('L')
query_images.append(self.transform(img))
query_labels.append(label)
# 合并成Tensor并打乱顺序(PyTorch需要)
support_images = torch.stack(support_images) # (n_way*k_shot, 1, 28, 28)
support_labels = torch.tensor(support_labels) # (n_way*k_shot,)
query_images = torch.stack(query_images) # (n_way*q_query, 1, 28, 28)
query_labels = torch.tensor(query_labels) # (n_way*q_query,)
return support_images, support_labels, query_images, query_labels
def __len__(self):
return 1000 # 自定义epoch数(实际可根据需求调整)
def train():
# 参数设置
n_way = 5
k_shot = 1
q_query = 15
lr = 0.001
epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化模型、数据加载器、优化器
encoder = Encoder().to(device)
dataset = OmniglotDataset(root='omniglot', split='train',
n_way=n_way, k_shot=k_shot, q_query=q_query)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True) # 每个batch是一个任务
optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)
for epoch in range(epochs):
encoder.train()
total_loss = 0
total_acc = 0
for batch in dataloader:
# 加载支持集和查询集(batch_size=1,所以取[0])
support_imgs, support_labels, query_imgs, query_labels = [x[0].to(device) for x in batch]
# 步骤1:提取支持集和查询集的特征
support_features = encoder(support_imgs) # (n_way*k_shot, 64)
query_features = encoder(query_imgs) # (n_way*q_query, 64)
# 步骤2:计算原型(n_way个类别,每个类别k_shot样本)
prototypes = compute_prototypes(support_features, support_labels, n_way) # (n_way, 64)
# 步骤3:计算查询样本与原型的距离(欧氏距离)
distances = euclidean_distance(query_features, prototypes) # (n_way*q_query, n_way)
# 步骤4:计算预测概率(负距离的softmax)
logits = -distances # 距离越近,logits越大(因为要最大化概率)
probs = F.softmax(logits, dim=1)
# 步骤5:计算交叉熵损失
loss = F.cross_entropy(logits, query_labels)
# 步骤6:反向传播更新模型
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 计算准确率
preds = probs.argmax(dim=1)
acc = (preds == query_labels).float().mean()
total_loss += loss.item()
total_acc += acc.item()
# 打印每轮训练结果
avg_loss = total_loss / len(dataloader)
avg_acc = total_acc / len(dataloader)
print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}')
if __name__ == '__main__':
train()
OmniglotDataset
将数据组织成"任务"(每个任务包含支持集和查询集),模拟少样本场景。例如,n_way=5
表示每个任务要学5个新类别,k_shot=1
表示每个类别只有1张支持图。Encoder
是一个简单的CNN,将28x28的灰度图转换为64维特征向量。卷积层的作用是提取图像的局部特征(如线条、弧度),全连接层将其整合为全局特征。compute_prototypes
函数通过平均支持集特征生成原型,这是少样本学习的核心——用少量样本总结出"典型特征"。医院数据库中可能只有5-10张某罕见病的CT图像(如"肺淋巴管肌瘤病"),传统深度学习无法训练。少样本学习可以用这些少量样本生成该病的"特征原型",快速识别新病例。
新的欺诈手段(如"虚假海外购物退款")刚出现时,只有少量标注案例。少样本学习能基于这些案例总结欺诈特征,实时检测新的欺诈行为。
用户可能自定义新的语音指令(如"打开儿童房氛围灯"),但只有3-5次历史录音。少样本学习可以快速学习这些新指令,无需重新训练大模型。
类型 | 推荐资源 | 说明 |
---|---|---|
数据集 | Omniglot(PyTorch可直接加载) | 少样本学习经典手写字符数据集 |
Mini-ImageNet(需要手动下载) | 更接近真实场景的小样本图像数据集 | |
开源库 | Torchmeta(https://github.com/tristandeleu/pytorch-meta) | PyTorch的元学习专用库,包含数据加载器和经典模型(如原型网络) |
论文 | 《Prototypical Networks for Few-shot Learning》(2017) | 原型网络的原始论文,详细推导数学原理 |
《Matching Networks for One Shot Learning》(2016) | 少样本学习早期经典论文,提出匹配网络 | |
教程 | PyTorch官方少样本学习教程(https://pytorch.org/tutorials/intermediate/meta_learning_tutorial.html) | 官方实战教程,包含MAML(另一种元学习算法)的实现 |
GPT-3等大模型展示了强大的"上下文学习"(In-Context Learning)能力,本质上是少样本学习。未来少样本学习可能与LLM深度融合,通过"提示(Prompt)"让模型仅用少量示例完成复杂任务(如代码生成、文本分类)。
当前少样本学习多集中在单模态(如图像),未来可能扩展到跨模态(如图像+文本)。例如,用少量"图像+描述"对,让模型学会根据文本描述生成或识别图像。
支持集样本极少时,模型容易记住个别样本的噪声(如照片中的光线干扰),导致原型不准确。需要更鲁棒的特征提取方法(如引入先验知识或数据增强)。
元训练阶段的历史任务可能与实际应用任务差异大(如用手写字符训练的模型,用于医学影像任务效果差)。需要研究"跨领域少样本学习",提升模型的泛化性。
支持集是"学习材料",元学习是"学习方法",原型网络是"具体技巧"。三者协作,让AI从"数据吃货"变身"学习高手",在样本极少的场景下也能高效工作。
Q1:少样本学习和零样本学习(Zero-shot Learning)有什么区别?
A:少样本学习需要少量样本(如5张),零样本学习不需要任何新类别样本(通过先验知识,如图像的文本描述)。例如,识别"霍加狓"时,少样本学习需要5张霍加狓的图,零样本学习需要知道"霍加狓是条纹像斑马、体型像长颈鹿的动物"的描述。
Q2:支持集的样本数(k-shot)越多越好吗?
A:不一定。当k增大到一定程度(如20),少样本学习和传统监督学习效果接近,但少样本学习的优势是"用更少样本达到相近效果"。实际中,k通常取1-5,因为更大的k可能增加标注成本。
Q3:原型网络和匹配网络(Matching Networks)有什么区别?
A:原型网络用"平均特征"作为原型,匹配网络用"注意力机制"让查询样本直接与支持集每个样本匹配(类似"记住所有小抄细节")。原型网络计算更简单,匹配网络理论上能利用支持集的更多细节,但容易过拟合。