PyTorch: resampling when loading data to handle class imbalance

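Contents

  • Resampling
  • I. Building the sampler
  • II. Getting the class of every image for the sampler
    • 1. The data-loading code already has a variable holding the class
    • 2. The class of each image is hidden in its path
    • 3. Adding the sampler to the DataLoader
  • Summary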


Resampling

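Resampling simply means building a weighted sampler that, while loading data, draws samples from the original dataset according to per-sample weights. Samples from rare classes get larger weights and are therefore drawn more often, so the batches the model sees are closer to class-balanced.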
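To make the weighting concrete: if each sample's weight is the inverse of its class size, every class ends up carrying the same total weight, so each class is equally likely to be drawn. A minimal sketch with made-up labels (the `targets` tensor is hypothetical toy data):

```python
import torch

# Hypothetical toy labels: class 0 is common (6 samples), class 1 is rare (2 samples)
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])

# Number of samples in each class
class_counts = torch.bincount(targets)

# Inverse-frequency weight per class, then look up each sample's weight by its label
class_weights = 1.0 / class_counts.float()
sample_weights = class_weights[targets]

# Sum the sample weights per class: both classes carry (about) 1.0 in total,
# so a weighted draw picks either class with equal probability
per_class_total = torch.zeros(len(class_counts)).index_add_(0, targets, sample_weights)
print(per_class_total)
```

Drawing one sample with these weights is equivalent to first picking a class uniformly at random and then picking a sample uniformly within that class.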


I. Building the sampler

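import torch
from torch.utils.data import WeightedRandomSampler

# Number of samples in each class of the dataset
class_sample_counts = [100, 200, 300, 400]
# Inverse-frequency weight per class: rarer classes get larger weights
weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)
# get_classes_for_all_imgs is the key part: it returns the class index of every image
train_targets = train_dataset.get_classes_for_all_imgs()
# Weight of each individual sample = weight of its class
samples_weights = weights[train_targets]

sampler = WeightedRandomSampler(weights=samples_weights, num_samples=len(samples_weights), replacement=True)
# With WeightedRandomSampler the order is already randomized (and balanced), so DO NOT shuffle again
# (setting and opt are the author's own config objects)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=setting.batch_size, shuffle=False,
                                           sampler=sampler, num_workers=opt.nThreads)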
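The snippet above depends on the author's `train_dataset`, `setting` and `opt` objects. The self-contained sketch below runs the same recipe on a hypothetical toy dataset (900 samples of class 0, 100 of class 1, random features) and counts how often each class appears in one epoch. It derives the per-class counts with `torch.bincount` instead of hardcoding them; that substitution is made here for illustration and is not part of the original method.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)  # fixed seed so the run is reproducible

# Hypothetical imbalanced dataset: 900 samples of class 0, 100 of class 1
labels = torch.cat([torch.zeros(900, dtype=torch.long), torch.ones(100, dtype=torch.long)])
features = torch.randn(len(labels), 4)
dataset = TensorDataset(features, labels)

# Count samples per class from the labels instead of hardcoding the list
class_sample_counts = torch.bincount(labels)
weights = 1.0 / class_sample_counts.float()
samples_weights = weights[labels]

sampler = WeightedRandomSampler(weights=samples_weights,
                                num_samples=len(samples_weights),
                                replacement=True)
# shuffle stays at its default (off): the sampler already randomizes the order
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Count how often each class appears in one epoch of resampled data
seen = torch.zeros(2, dtype=torch.long)
for _, y in loader:
    seen += torch.bincount(y, minlength=2)
print(seen)  # roughly balanced instead of 900 / 100
```

Because `replacement=True` and `num_samples` equals the dataset size, an epoch keeps its original length, but rare-class samples are repeated within it and some common-class samples are not drawn at all.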


II. Getting the class of every image for the sampler

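1. The data-loading code already has a variable holding the class

Here the dataset class is DataGenerator, so train_dataset in the code above is a DataGenerator instance, and train_targets = train_dataset.get_classes_for_all_imgs() returns the labels it collected while parsing its annotation lines. Call the method on the instance: get_classes_for_all_imgs is an instance method, so calling DataGenerator.get_classes_for_all_imgs() on the class itself would fail. Code (example):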

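from torch.utils import data

class DataGenerator(data.Dataset):
    def __init__(self, image_lines, input_shape, random=True, autoaugment_flag=True):
        super(DataGenerator, self).__init__()
        self.image_lines = image_lines
        self.input_shape = input_shape
        self.random = random
        self.autoaugment_flag = autoaugment_flag
        # Collect the class of every image; the text before the first ';' is the class index
        self.classes_for_all_imgs = []
        for line in self.image_lines:
            y = int(line.split(';')[0])  # y is the class index
            self.classes_for_all_imgs.append(y)

    def __len__(self):
        return len(self.image_lines)

    def get_classes_for_all_imgs(self):
        return self.classes_for_all_imgs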
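Since `DataGenerator` already collects the label of every line, the hardcoded `class_sample_counts` list from section I can be derived from those labels. A sketch assuming annotation lines whose text before the first `;` is the class index (the sample lines and file names below are hypothetical):

```python
from collections import Counter

# Hypothetical annotation lines; only the part before the first ';' (the class
# index) matters here, the rest stands in for an image path
image_lines = [
    "0;images/cat_001.jpg",
    "0;images/cat_002.jpg",
    "0;images/cat_003.jpg",
    "1;images/dog_001.jpg",
]

# Same parsing rule as DataGenerator.__init__
classes_for_all_imgs = [int(line.split(';')[0]) for line in image_lines]

# Per-class counts ordered by class index, to replace the hardcoded list
counts = Counter(classes_for_all_imgs)
class_sample_counts = [counts[c] for c in range(max(counts) + 1)]
print(classes_for_all_imgs)  # [0, 0, 0, 1]
print(class_sample_counts)   # [3, 1]
```

A class index with no images would get a count of 0 here; that only yields an infinite weight for a class that no sample actually uses, so the per-sample weights stay finite.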

2. The class of each image is hidden in its path

This requires importing the glob and os modules. Code (example):

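import glob
import os
from torch.utils import data

class TrainDataset(data.Dataset):
    def __init__(self, data_root):
        super(TrainDataset, self).__init__()
        self.data_root = data_root
        self.img_paths = sorted(glob.glob(os.path.join(self.data_root, '*')))
        self.classes_for_all_imgs = []
        for img_path in self.img_paths:
            # Assumed: the class letter appears in the file name. Use `in`, not
            # str.find(): find() returns -1 (truthy) when the letter is missing
            # and 0 (falsy) when it is the first character.
            name = os.path.basename(img_path)
            if 'A' in name:
                class_id = 0
            elif 'B' in name:
                class_id = 1
            elif 'C' in name:
                class_id = 2
            elif 'D' in name:
                class_id = 3
            else:
                raise ValueError('cannot infer class from path: ' + img_path)
            self.classes_for_all_imgs.append(class_id)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        ...  # load and return the image at self.img_paths[index] and its label

    def get_classes_for_all_imgs(self):
        return self.classes_for_all_imgs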
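Note that `str.find` returns -1 (which is truthy) when the substring is missing and 0 (falsy) when it is found at position 0, so its result cannot be used directly as an `if` condition; a membership test with `in` avoids this. Below is a sketch of a table-driven alternative to an `if/elif` chain, assuming (hypothetically) that the class letter appears in the file name; `CLASS_IDS` and `class_id_from_path` are illustrative names, not part of the original code:

```python
import os

# Hypothetical mapping from the letter that marks a class in the file name
# to its integer class id; letters are checked in this order
CLASS_IDS = {'A': 0, 'B': 1, 'C': 2, 'D': 3}

def class_id_from_path(img_path):
    """Return the class id of the first known class letter in the file name."""
    name = os.path.basename(img_path)  # ignore letters in the directory part
    for letter, class_id in CLASS_IDS.items():
        if letter in name:
            return class_id
    raise ValueError(f"no class letter found in {img_path!r}")

paths = ["/data/train/A_001.jpg", "/data/train/C_017.jpg", "/data/train/B_002.jpg"]
print([class_id_from_path(p) for p in paths])  # [0, 2, 1]
```

Keeping the mapping in one dict makes adding a class a one-line change, and file names with no known letter raise a descriptive error.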

3. Adding the sampler to the DataLoader

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler, shuffle=False)

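shuffle must be set to False (or left out): DataLoader does not allow a sampler to be combined with shuffle=True and raises a ValueError if you try.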
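A quick check of why `shuffle` has to stay off: passing a sampler together with `shuffle=True` makes the `DataLoader` constructor raise a `ValueError`. A minimal sketch on a hypothetical toy dataset of ten integers with uniform weights:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

dataset = TensorDataset(torch.arange(10))
sampler = WeightedRandomSampler(weights=[1.0] * 10, num_samples=10, replacement=True)

# Wrong: sampler and shuffle=True cannot be combined
try:
    DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=True)
    error = None
except ValueError as e:
    error = e
print(error)  # explains that sampler and shuffle are mutually exclusive

# Correct: leave shuffle off and let the sampler decide the order
loader = DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=False)
print(len(loader))  # 5 (10 sampled indices in batches of 2)
```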


Summary

The method is adapted from https://blog.csdn.net/Hungryof/article/details/107609877
