pytorch中Dataloader的使用

1.大致流程

pytorch中加载数据的顺序是:
①创建一个dataset对象
②创建一个dataloader对象
③循环dataloader对象,将data,label拿到模型中去训练

2.Dataset

你需要自己定义一个class,里面至少包含3个函数:

(特别要注意的是输入进函数的数据一定得是可迭代的。如果是自定的数据集的话可以在定义类中用def__len__、def__getitem__定义。)
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个item
③__getitem__:返回一条训练数据,并将其转换成tensor

3.Dataloader

参数:
dataset:传入的数据
shuffle = True:是否打乱数据
collate_fn:使用这个参数可以自己操作每个batch的数据

(collate_fn暂时用不到,可以参考Pytorch中DataLoader的使用_kahuifu的博客-CSDN博客_dataloader)

4.按照batch取数据和标签

5.代码

import torch
from torch.utils.data import DataLoader,Dataset
import numpy as np

class Mydata(Dataset):
    def __init__(self, train_x, train_label):
        self.train_x = train_x
        self.train_label = train_label

    def __getitem__(self, item):
        assert item

执行结果

(注:如果在定义Dataset类的时候,在方法__getitem__中不加入return self.train_label[item]这一条命令的话,最终在按照batch取数据的时候,不会取label,只会取训练数据)

0 [tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)]
tensor([[1],
        [0]], dtype=torch.int32)
1 [tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)]
tensor([[3],
        [2]], dtype=torch.int32)

Process finished with exit code 0

你可能感兴趣的:(pytorch,深度学习,人工智能)