pytorch DataLoader处理不定长序列

本篇博客的目的是:
将下图这样的输入(每个tensor表示一个句子,01为句子标签):
pytorch DataLoader处理不定长序列_第1张图片
转化为下图所示的输出(batch_size=2)
元组的第一个元素为填充后的句子向量,第二个元素为句子长度,第三个元素为句子的label。
pytorch DataLoader处理不定长序列_第2张图片
为什么需要这样的处理?
    如果需要使用RNN模型处理序列数据,肯定不能将变长的序列直接输入模型,所以需要在输入前对其进行填充。这里需要注意的是,在有些情况下,输入数据不仅需要填充,并且需要在数据传送过程中记录句子的原始长度,例如在RNN中,如果句子长度差别较大,例如最大长度是50,但大多数句子长度<10,这样会导致很多句子中有很多填充的0,这会导致最后得到的hn是相同的。

第一步:建立Dataset
class SentenceDataSet(Dataset):
    def __init__(self, sent, sent_label):
        self.sent = sent
        self.sent_label = sent_label

    def __getitem__(self, item):
        return self.sent[item], self.sent_label[item]

    def __len__(self):
        return len(self.sent)


dataset = SentenceDataSet(x, y)

for xi, yi in dataset:
    print(xi, yi)

输出为:
pytorch DataLoader处理不定长序列_第3张图片

第二步:建立DataLoader

    在建立dataloader时需要注意:默认情况下dataloader会将batch中的所有数据条目的各个维度的数据直接进行拼接,由于在本博客的例子中句子序列的长度不相等,不能直接拼接,所以下面通过自己实现collate_fn函数自定义batch中的多个数据条目的组织方式。

def collate_fn(batch_data):
    """
    自定义 batch 内各个数据条目的组织方式
    :param data: 元组,第一个元素:句子序列数据,第二个元素:长度 第2维:句子标签
    :return: 填充后的句子列表、实际长度的列表、以及label列表
    """
    # batch_data 为一个batch的数据组成的列表,data中某一元素的形式如下
    # (tensor([1, 2, 3, 5]), 4, 0)
    # 后续将填充好的序列数据输入到RNN模型时需要使用pack_padded_sequence函数
    # pack_padded_sequence函数要求要按照序列的长度倒序排列
    batch_data.sort(key=lambda xi: len(xi[0]), reverse=True)
    data_length = [len(xi[0]) for xi in batch_data]
    sent_seq = [xi[0] for xi in batch_data]
    label = [xi[2] for xi in batch_data]
    padded_sent_seq = pad_sequence(sent_seq, batch_first=True, padding_value=0)
    return padded_sent_seq, data_length, torch.tensor(y, dtype=torch.float32)


# 将dataset封装为data_loader
data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

for data in data_loader:
    print(data)

输出:
pytorch DataLoader处理不定长序列_第4张图片

参考链接
pytorch dataloader官方教程
博客:处理LSTM+embedding变长序列

你可能感兴趣的:(自然语言处理)