pytorch中的DataLoader

在 PyTorch 中,DataLoader 是一个工具类,用于高效地加载数据并准备数据输入到模型中。它支持数据的批量加载、随机打乱、并行加载和迭代操作,是训练深度学习模型的关键组件之一。


1. 基本功能

DataLoader 的主要职责是从数据集中提取样本,并根据设置返回一个批次的数据。它与 Dataset 类结合使用:

  • Dataset:定义数据集的来源、结构以及如何获取单个数据样本。
  • DataLoader:负责从 Dataset 中提取样本、组织批次以及提供训练和测试时的数据接口。

2. 初始化参数

DataLoader 的常用初始化参数如下:

torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    persistent_workers=False
)
主要参数详解
  1. dataset

    • 类型:torch.utils.data.Dataset 的实例。
    • 数据的来源,必须实现 __getitem____len__ 方法。
    • 每次从 dataset 中提取一个样本。
  2. batch_size

    • 每个批次的样本数量,默认为 1
    • 例如,batch_size=32 表示每次返回32个样本。
  3. shuffle

    • 是否在每个 epoch 开始时打乱数据顺序。
    • 通常训练集设置为 True,测试集设置为 False
  4. samplerbatch_sampler

    • sampler:控制样本的抽样顺序(例如自定义采样器)。
    • batch_sampler:控制批次的抽样顺序,覆盖 batch_sizeshuffle
  5. num_workers

    • 数据加载的并行进程数,默认为 0(即仅使用主进程加载数据)。
    • 增大 num_workers 值可以加速数据加载,但需要注意系统资源。
  6. collate_fn

    • 自定义批处理函数,定义如何将多个样本合成一个批次。
    • 默认将每个样本拼接为批量张量(如果可能)。
  7. pin_memory

    • 是否将数据加载到固定内存中(Pinned Memory),提高 GPU 训练时数据传输效率。
  8. drop_last

    • 是否丢弃最后一个不足 batch_size 的批次。
    • 默认值为 False,即保留不足的最后一个批次。
  9. timeout

    • 数据加载的超时时间(秒)。
    • 若超时未加载成功,会抛出异常。
  10. worker_init_fn

    • 为每个工作进程设置初始化函数,常用于自定义子进程的随机种子等设置。
  11. persistent_workers(PyTorch 1.7+ 引入)persistent_workers (PyTorch 1.7+ 引入)

    • 若设置为 Truenum_workers 启动的工作进程在每个 epoch 后不会被销毁,从而减少重新启动的开销。

3. 使用示例

a. 基本使用
from torch.utils.data import DataLoader, Dataset

# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

# 示例数据
data = [i for i in range(10)]
dataset = CustomDataset(data)

# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True)

# 迭代访问数据
for batch in dataloader:
    print(batch)

输出(随机打乱后的批次):

tensor([6, 8, 3])
tensor([1, 4, 5])
tensor([7, 0, 9])
tensor([2])

b. 使用多进程加载
dataloader = DataLoader(dataset, batch_size=3, num_workers=2)

for batch in dataloader:
    print(batch)
  • num_workers=2 表示使用两个子进程并行加载数据。

c. 使用自定义 collate_fn
# 自定义 collate 函数:将批次数据加倍
def custom_collate_fn(batch):
    return [x * 2 for x in batch]

dataloader = DataLoader(dataset, batch_size=3, collate_fn=custom_collate_fn)

for batch in dataloader:
    print(batch)

输出

[12, 16, 6]
[2, 8, 10]
[14, 0, 18]
[4]

4. 工作原理

  1. 抽样

    • 数据加载器根据 samplershuffle 参数,确定要加载的样本索引。
  2. 加载样本

    • 调用 Dataset__getitem__ 方法加载单个样本。
  3. 批处理

    • 将多个样本通过 collate_fn 组合成一个批次。
  4. 并行加载

    • 如果设置了 num_workers > 0,会启动多个子进程来加速数据加载。

5. 注意事项

  1. GPU 加速与 pin_memory
    如果使用 GPU 训练,设置 pin_memory=True 可以加速主机到 GPU 的数据传输。

  2. 批量大小的选择

    • 批量大小会影响显存占用和模型训练效率。
    • 通常设置为 2 的幂(如 16、32、64)以便于硬件优化。
  3. 并行加载的平衡

    • num_workers 的选择应根据 CPU 核心数和硬件性能调整。
    • 太多子进程可能导致资源竞争,从而降低性能。
  4. 动态形状的数据

    • 若数据形状不一致(例如 NLP 中的文本长度不同),需提供自定义 collate_fn 进行补齐或处理。

6. 最后

DataLoader 是 PyTorch 中高效的数据加载工具,通过与 Dataset 配合,能够高效、灵活地提供数据批次。它支持:

  • 数据的随机打乱和批量化。
  • 自定义的数据加载和处理逻辑。
  • 多线程并行加载,提高数据预处理效率。

你可能感兴趣的:(pytorch,人工智能,python)