class CustomDataset(Dataset):
# 自定义自己的dataset
dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)
for data in dataloader:
# training...
__iter__
() 方法, 产生了一个DataLoaderIter
__next__()
来得到batch, 具体操作就是, 多次调用Dataset的__getitem__()
方法 (如果num_worker>0
就多线程调用), 然后用collate_fn
来把它们打包成batch. 中间还会涉及到shuffle
, 以及sample
的方法等, 这里就不多说了.__next__()
抛出一个StopIteration异常, for循环结束, dataloader 失效__getitem __()
,以获取指定Key的数据样本__len__()
,以返回整个数据集的长度class my_dataset(torch.utils.data.Dataset):
def __init__(self, trainingImageDir, bndbox, keypointsPixel, keypointsWorld, center):
self.trainingImageDir = trainingImageDir
self.mean = Img_mean
self.std = Img_std
self.bndbox = bndbox
self.keypointsPixel = keypointsPixel
self.keypointsWorld = keypointsWorld
self.center = center
self.depth_thres = 0.4
def __getitem__(self, index):
# data4DTemp = scio.loadmat(self.trainingImageDir + str(index+1) + '.mat')['DepthNormal']
data4DTemp = scio.loadmat(self.trainingImageDir + str(index) + '.mat')['DepthNormal']
depthTemp = data4DTemp[:,:,3]
data, label = dataPreprocess(index, depthTemp, self.keypointsPixel, self.keypointsWorld, self.bndbox, self.center, self.depth_thres)
return data, label
def __len__(self):
return len(self.bndbox)
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
参数 | 说明 |
---|---|
dataset | 从中加载数据的数据集 |
batch_size | 每批次要加载的样本数 |
shuffle | 设置为True以使数据在每个训练epoch都重新洗牌 |
collate_fn | 这个函数用来打包batch |
sampler | 定义从数据集中抽取样本的策略。 如果指定,则shuffle必须为False |
batch_sampler | 类似于采样器(sampler),但一次返回一批索引。 与batch_size,shuffle,sampler和drop_last互斥 |
num_workers | 用于数据加载的子进程数量。 0表示将在主进程中加载数据。 (默认值:0) |
collate_fn | 合并样本列表以形成小批量 |
pin_memory | 如果为True,则数据加载器在将张量返回之前将其复制到CUDA固定的内存中 |
drop_last | 如果数据集大小不能被批量大小整除,则设置为True以删除最后一个不完整的批量。 如果为False并且数据集的大小不能被批次大小整除,则最后一批将较小(默认值:False) |
timeout | 如果为正,则为从工作进程收集批次的超时值。 应始终为非负数 (默认值:0) |
worker_init_fn | 如果不为None,则在种子创建之后和数据加载之前,将在每个工作子进程上以工作ID([0,num_workers-1]中的int)作为输入来调用此方法 (默认值:无) |
__iter__()
函数, 把自己 "装进"DataLoaderIter 里面def __iter__(self):
return DataLoaderIter(self)
工作流程
__getitem__
函数获取单个的数据示例
test_image_datasets = my_dataset(testingImageDir, bndbox_test, keypointsPixeltest, keypointsWorldtest, center_test)
test_dataloaders = torch.utils.data.DataLoader(test_image_datasets, batch_size = batch_size,
shuffle = False, num_workers = 8)