在学习的过程中,遇到很多的问题.在训练模型的时候,需要对数据进行读取操作.本篇文章,介绍pytorch如何自定义数据dataset和dataloader.在pytorch中,提供了Dataset这个类,负责对数据进行抽象,一次调用只返回第一个同样本.而Dataloader提供了对一个对一个batch的数据操作,还有对数据进行shuffle等功能.
Dataset是一个抽象类,表示一个数据集.所有的数据集类都是Dataset的子类.在自定义数据集的时候,需要对函数**__getitem()**, **__len()**重写.
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
#每调用一次,返回一个数据样本, index表示返回样本的索引位置
#在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。然后在数据集中,随机加入一条数据
def __getitem__(self, index):
raise NotImplementedError
#__len__ 数据集的长度
def __len__(self):
raise NotImplementedError
#增加数据,对数据集进行拼接
def __add__(self, other):
return ConcatDataset([self, other])
DataLoader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它或者 batch_datas, batch_labels in dataloader:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
'''
dataset:加载的数据集(Dataset对象)
batch_size:batch size
shuffle::是否将数据打乱
sampler: 样本抽样,后续会详细介绍
num_workers:使用多进程加载的进程数,0代表不使用多进程
collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
'''
在本次的例子中,我们利用Dataset和DataLoader处理预测姓名的性别的原始数据.下面是完整的例子.
# coding=utf-8
import random
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
################################load dataset ###################
################################################################
def load_data(path, label):
names = []
with open(path) as f:
lines = f.readlines()
for l in lines:
names.append((l.strip('\n'), label))
return names
female_names = load_data('../datasets/names_gender/eng/female.txt', 0)
male_names = load_data('../datasets/names_gender/eng/male.txt', 1)
names = female_names + male_names
random.shuffle(names)
# 将数据划分为训练集和测试集
train_dataset = names[: int(len(names)*0.9)]
test_dataset = names[int(len(names)*0.9)]
class NameDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
# 这里可以对数据进行处理,比如讲字符数值化
data = self.data[index]
name = data[0]
label = data[1]
return name, label
def __len__(self):
return len(self.data)
train_dataset = NameDataset(train_dataset)
test_dataset = NameDataset(test_dataset)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=True)
for data in train_loader:
print(data[0])
break
#结果
('Danelle', 'Ransell', 'Tanya', 'Rutter', 'Herschel')
tensor([0, 1, 0, 1, 1])