pytorch中加载训练数据集的两种方法

在pytorch中使用torch.utils.data.Dataset创建训练数据集,使用torch.utils.data.DataLoader数据加载器加载出一个batch size的数据。因此为了完成一个epoch中对于整个训练数据集的学习,需要对于data.DataLoader进行访问,无论是通过enumerate还是iter方法,其返回对象都是一个列表,列表长度与所定义的dataset的__getitem__方法的返回值个数相等,其中列表中的每个元素是torch.Tensor类型,shape在第0个维度上等于batch size,在剩下的维度上与__getitem__()方法的相应返回值相同,相当于将batch size个__getitem__()方法的相应返回值进行stack,并在第0个纬度上进行dimension expand。

import torch
import torch.nn as nn
import torch.utils.data as data
class train_set(data.Dataset):
    def __init__(self):
        self.data=[1,2,3,4,5,6,7]
    def __getitem__(self, item):
        return torch.ones((2,2)),self.data[item]
    def __len__(self):
        return len(self.data)

train_set=train_set()
dataloader=data.DataLoader(train_set,batch_size=2,shuffle=False)
# for j in range(1):
#     data_iter=iter(dataloader)
#     num_image=4
#     # 循环:
#     for i in range(num_image):
#         x=next(data_iter)
#         print(type(x),len(x))# 2
        # print(type(x[0]), x[0].shape)# torch.Size([2, 2, 2])
        # print(type(x[1]),x[1].shape)# torch.Size([2])
    # print('---------我是分割线-----------')

for j in range(1):
    # data_iter=iter(dataloader)
    # num_image=4
    # # 循环:
    num_image=4
    for i,data in enumerate(dataloader):
        print(type(data),len(data))# 2
        print(type(data[0]),data[0].shape)# torch.Size([2, 2, 2])
        print(type(data[1]),data[1].shape)# torch.Size([2])
        if i==num_image-1:
            break
    # print('---------我是分割线-----------')

你可能感兴趣的:(pytorch)