dataset 报错:raise keyerror (key) from err 、too many indexers

【1】原始代码:

    def __getitem__(self, index):
        wt_feature = self.wt_features[index]
        mt_feature = self.mt_features[index]
        label = self.true_ddg[index]

        # 将特征和标签转换为张量类型
        wt_feature = torch.tensor(wt_feature, dtype=torch.float32)
        mt_feature = torch.tensor(mt_feature, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.float32)

        return {"wt_feature": wt_feature, "mt_feature": mt_feature, "label": label}

在之后训练过程中,使用dataloader 在for batch 的时候出现报错:

raise keyerror (key) from err

【解释】:该报错的原因是存在超过范围的索引

【原因】:

wt_feature = self.wt_features[index]
mt_feature = self.mt_features[index]
label = self.true_ddg[index]

这里输入的wt_features mt_features 是dataframe 类型,取值应该换为以下:


wt_feature = self.wt

你可能感兴趣的:(深度学习,人工智能)