PyTorch中的Dataset
是一个抽象类,所有自定义的数据集都应该继承这个类,并实现以下两个方法:
__len__()
: 返回数据集的大小__getitem__()
: 根据索引返回一个样本概念解析:
Dataset
类提供了统一的数据访问接口Dataset
,我们可以轻松地将数据集成到PyTorch的生态系统中import torch
from torch.utils.data import Dataset
from PIL import Image
import os
class CustomImageDataset(Dataset):
def __init__(self, img_dir, transform=None):
"""
初始化数据集
参数:
img_dir (str): 图像目录路径
transform (callable, optional): 可选的数据变换函数
"""
self.img_dir = img_dir
self.transform = transform
# 假设目录中每个文件都是一个图像样本
self.img_names = os.listdir(img_dir)
def __len__(self):
"""返回数据集中的样本数"""
return len(self.img_names)
def __getitem__(self, idx):
"""根据索引获取样本"""
img_path = os.path.join(self.img_dir, self.img_names[idx])
image = Image.open(img_path).convert('RGB') # 转换为RGB格式
if self.transform:
image = self.transform(image)
# 假设文件名格式为"label_image.jpg"
label = int(self.img_names[idx].split('_')[0])
return image, label
注意事项:
__getitem__
中执行耗时的操作(如文件读取)可能会影响性能__len__
返回的是整数__init__
中完成,而不是在__getitem__
中DataLoader
是PyTorch提供的一个迭代器,它封装了Dataset
,提供以下功能:
主要参数:
dataset
: 要加载的数据集(Dataset对象)batch_size
: 每个batch的大小(默认: 1)shuffle
: 是否在每个epoch打乱数据(默认: False)num_workers
: 用于数据加载的子进程数(默认: 0,表示在主进程中加载)pin_memory
: 是否将数据复制到CUDA固定内存中(加速GPU传输)drop_last
: 是否丢弃最后一个不完整的batch(默认: False)实例:
from torch.utils.data import DataLoader
# 创建数据集实例
dataset = CustomImageDataset(img_dir='data/images', transform=None)
# 创建DataLoader
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=4,
pin_memory=True
)
# 使用示例
for batch_images, batch_labels in dataloader:
# batch_images是形状为[32, C, H, W]的张量
# batch_labels是形状为[32]的张量
pass
注意事项:
num_workers
设置过高可能导致内存问题if __name__ == '__main__':
块中pin_memory
在GPU训练时可以显著提高数据传输速度batch_size
整除时,最后一个batch会比较小PyTorch的torchvision.datasets
模块提供了许多常用的数据集:
实例:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义数据变换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(
root='./data', # 数据存储路径
train=True, # 训练集
download=True, # 如果数据不存在则下载
transform=transform # 应用的数据变换
)
test_dataset = datasets.MNIST(
root='./data',
train=False, # 测试集
download=True,
transform=transform
)
# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
注意事项:
root
参数指定数据存储位置,确保有足够的磁盘空间PyTorch提供了多种数据预处理变换:
基础变换:
ToTensor()
: 将PIL图像或NumPy数组转换为PyTorch张量Normalize(mean, std)
: 用给定的均值和标准差对张量进行归一化Resize(size)
: 调整图像大小CenterCrop(size)
: 中心裁剪Pad(padding)
: 填充几何变换:
RandomHorizontalFlip(p=0.5)
: 随机水平翻转RandomVerticalFlip(p=0.5)
: 随机垂直翻转RandomRotation(degrees)
: 随机旋转RandomAffine(degrees, translate, scale, shear)
: 随机仿射变换颜色变换:
ColorJitter(brightness, contrast, saturation, hue)
: 随机改变亮度、对比度、饱和度和色调Grayscale(num_output_channels=1)
: 转换为灰度图RandomGrayscale(p=0.1)
: 以概率p转换为灰度图组合变换:
Compose([transforms])
: 将多个变换组合在一起from torchvision import transforms
# 定义训练集变换
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224), # 随机裁剪并调整大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), # 随机颜色变换
transforms.ToTensor(), # 转换为张量
transforms.Normalize( # 归一化
mean=[0.485, 0.456, 0.406], # ImageNet均值
std=[0.229, 0.224, 0.225] # ImageNet标准差
)
])
# 定义测试集变换 (通常不包含数据增强)
test_transform = transforms.Compose([
transforms.Resize(256), # 调整大小
transforms.CenterCrop(224), # 中心裁剪
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
注意事项:
ToTensor()
通常应该在Normalize()
之前我们可以通过实现__call__
方法创建自定义变换:
import random
import torchvision.transforms.functional as F
class RandomRotationWithReflect:
"""随机旋转并保持图像大小不变(通过反射填充)"""
def __init__(self, degrees):
self.degrees = degrees
def __call__(self, img):
angle = random.uniform(-self.degrees, self.degrees)
return F.rotate(img, angle, fill=(0, 0, 0)) # 用黑色填充
# 使用自定义变换
custom_transform = transforms.Compose([
transforms.Resize(256),
RandomRotationWithReflect(30), # 自定义变换
transforms.ToTensor()
])
注意事项:
__call__
中执行耗时操作下面是一个完整的图像分类任务数据加载和预处理示例:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os
from PIL import Image
# 1. 自定义数据集类
class CustomImageDataset(Dataset):
def __init__(self, img_dir, transform=None):
self.img_dir = img_dir
self.transform = transform
self.classes = sorted(os.listdir(img_dir)) # 假设每个子目录是一个类别
self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
self.samples = []
# 收集所有图像路径和标签
for cls in self.classes:
cls_dir = os.path.join(img_dir, cls)
for img_name in os.listdir(cls_dir):
self.samples.append((
os.path.join(cls_dir, img_name),
self.class_to_idx[cls]
))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
img = Image.open(img_path).convert('RGB')
if self.transform:
img = self.transform(img)
return img, label
# 2. 定义数据变换
# 训练集变换 (包含数据增强)
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(30),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 验证集/测试集变换 (不包含数据增强)
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 3. 创建数据集
train_dataset = CustomImageDataset(
img_dir='data/train',
transform=train_transform
)
val_dataset = CustomImageDataset(
img_dir='data/val',
transform=val_transform
)
# 4. 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True
)
val_loader = DataLoader(
val_dataset,
batch_size=64,
shuffle=False,
num_workers=4,
pin_memory=True
)
# 5. 使用示例
def train_model():
for epoch in range(10):
# 训练阶段
model.train()
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
# 训练代码...
# 验证阶段
model.eval()
with torch.no_grad():
for images, labels in val_loader:
images = images.to(device)
labels = labels.to(device)
# 验证代码...
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_model()
Dataset类:
__len__
和__getitem__
方法DataLoader:
batch_size
, shuffle
, num_workers
, pin_memory
标准数据集:
torchvision.datasets
提供了常用数据集数据变换:
transforms.Compose
用于组合多个变换ToTensor
通常在Normalize
之前)最佳实践:
num_workers
加速数据加载pin_memory
通过合理的数据加载和预处理流程,可以显著提高模型训练效率和最终性能。数据预处理应该与模型的需求和数据的特性相匹配,同时考虑计算效率和内存使用。