pytorch实现数据数据读取

文章目录

  • Dataset
  • DataLoader
  • 例子

本文原地址为 www.zicesun.com icesuns.github.io

在学习的过程中,遇到很多的问题.在训练模型的时候,需要对数据进行读取操作.本篇文章,介绍pytorch如何自定义数据dataset和dataloader.在pytorch中,提供了Dataset这个类,负责对数据进行抽象,一次调用只返回第一个同样本.而Dataloader提供了对一个对一个batch的数据操作,还有对数据进行shuffle等功能.

Dataset

Dataset是一个抽象类,表示一个数据集.所有的数据集类都是Dataset的子类.在自定义数据集的时候,需要对函数**__getitem()**, **__len()**重写.

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    #每调用一次,返回一个数据样本, index表示返回样本的索引位置
    #在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。然后在数据集中,随机加入一条数据
    def __getitem__(self, index):
        raise NotImplementedError
    #__len__ 数据集的长度
    def __len__(self):
        raise NotImplementedError
    #增加数据,对数据集进行拼接
    def __add__(self, other):
        return ConcatDataset([self, other])

DataLoader

DataLoader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它或者 batch_datas, batch_labels in dataloader:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
'''
dataset:加载的数据集(Dataset对象) 
batch_size:batch size 
shuffle::是否将数据打乱 
sampler: 样本抽样,后续会详细介绍 
num_workers:使用多进程加载的进程数,0代表不使用多进程 
collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可 
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些 
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
'''

例子

在本次的例子中,我们利用Dataset和DataLoader处理预测姓名的性别的原始数据.下面是完整的例子.

# coding=utf-8
import random
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn


################################load dataset ###################
################################################################
def load_data(path, label):
    names = []
    with open(path) as f:
        lines = f.readlines()
        for l in lines:
            names.append((l.strip('\n'), label))
    return names
female_names = load_data('../datasets/names_gender/eng/female.txt', 0)
male_names = load_data('../datasets/names_gender/eng/male.txt', 1)
names = female_names + male_names
random.shuffle(names)

# 将数据划分为训练集和测试集
train_dataset = names[: int(len(names)*0.9)]
test_dataset = names[int(len(names)*0.9)]

class NameDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # 这里可以对数据进行处理,比如讲字符数值化
        data = self.data[index]
        name = data[0]
        label = data[1]
        return name, label

    def __len__(self):
        return len(self.data)

train_dataset = NameDataset(train_dataset)
test_dataset = NameDataset(test_dataset)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=True)


for data in train_loader:
    print(data[0])
    break
#结果
('Danelle', 'Ransell', 'Tanya', 'Rutter', 'Herschel')
tensor([0, 1, 0, 1, 1])

你可能感兴趣的:(深度学习,机器学习)