torch.utils.data.DataLoader
中的参数:
Dataloader的处理逻辑是先通过Dataset类里面的 __getitem__
函数获取单个的数据,然后组合成batch,再使用collate_fn所指定的函数对这个batch做一些操作,比如padding啊之类的。
在NLP
中的使用主要是要重构两个两个东西,一个是dataset
,必须继承自torch.utils.data.Dataset
,内部要实现两个函数一个是__lent__
用来获取整个数据集的大小,一个是__getitem__
用来从数据集中得到一个数据片段item
。
class Dataset(torch.utils.data.Dataset):
def __init__(self, filepath=None,dataLen=None):
self.file = filepath
self.dataLen = dataLen
def __getitem__(self, index):
A,B,path,hop= linecache.getline(self.file, index+1).split('\t')
return A,B,path.split(' '),int(hop)
def __len__(self):
return self.dataLen
因为dataloader
是有batch_size
参数的,我们可以通过自定义collate_fn=myfunction
来设计数据收集的方式,意思是已经通过上面的Dataset
类中的__getitem__
函数采样了batch_size
数据,以一个包的形式传递给collate_fn
所指定的函数。
def myfunction(data):
A,B,path,hop=zip(*data)
print('A:',A," B:",B," path:",path," hop:",hop)
raise Exception('utils collate_fun 147')
return A,B,path,hop
for index,item in enumerate(dataloaders['train'])
A,B,path.hop=item
nlp任务中,经常在collate_fn指定的函数里面做padding,就是将在同一个batch中不一样长的句子padding成一样长。
def myfunction(data):
src, tgt, original_src, original_tgt = zip(*data)
src_len = [len(s) for s in src]
src_pad = torch.zeros(len(src), max(src_len)).long()
for i, s in enumerate(src):
end = src_len[i]
src_pad[i, :end] = torch.LongTensor(s[end-1::-1])
tgt_len = [len(s) for s in tgt]
tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long()
for i, s in enumerate(tgt):
end = tgt_len[i]
tgt_pad[i, :end] = torch.LongTensor(s)[:end]
return src_pad, tgt_pad, \
torch.LongTensor(src_len), torch.LongTensor(tgt_len), \
original_src, original_tgt