PyTorch DataLoader、DataSet

研究了下DataLoader大批量加载数据的原理:DataLoader只负责数据的抽象,一次调用getitem只返回一个样本

import torch
from torch.utils.data import DataLoader,Dataset
import numpy as np
import pandas as pd

class DataSet(Dataset):
    def __init__(self):
        data = np.loadtxt('/Users/yq/Desktop/test_data/new.csv',delimiter=',',dtype=np.float32)
        self.x_data = torch.from_numpy(data[:,0:-1])
        self.y_data = torch.from_numpy(data[:,-2:)
        self.len = data.shape[0]
    def __getitem__(self,index):
        return self.x_data[index],self.y_data[index]
    def __len__(self):
        return self.len


dataset = DataSet()
data_loader = DataLoader(dataset=dataset,batch_size=5,shuffle=True)


for i,data in enumerate(data_loader):
    item,label = data
    print(item)
    print(label)

结果示例如下:

PyTorch DataLoader、DataSet_第1张图片

在DataLoader中设置了batch_size,shuffle,查看了下源码

PyTorch DataLoader、DataSet_第2张图片

PyTorch DataLoader、DataSet_第3张图片

for i,data in enumerate(data_loader):

i是range(self.len / batch_size),表示迭代的批次,而每次迭代时,batch_size行数据因为shuffle会被随机选择,如上图所示。

在这里__getitem__内置函数的作用在于:在当前的shuffle序列[9,22,15,16,19],index=9,则返回一行数据(self.x_data[index],self.y_data[index]),以此类推,迭代了len(batch_size),获取了当前批次的数据,开始准备进行训练。

你可能感兴趣的:(NLP)