通过txt文件来构建自己的Dataset和DataLoader

最近在跑自己的实验室数据,但是遇到了几个问题,就是自己的数据集与一些kaggle里面的数据集不一样,那这样的画,就需要自己去构建类似kaggle里面那样的数据集,但是自己的数据集又很庞大,复制移动文件很慢,所以就使用txt文件来记录数据的路径和标签,直接通过txt文件来构建DataSet和DataLoader

1-构建自己的train.txt和val.txt

1、首先我的目录结构如下

通过txt文件来构建自己的Dataset和DataLoader_第1张图片通过txt文件来构建自己的Dataset和DataLoader_第2张图片
通过txt文件来构建自己的Dataset和DataLoader_第3张图片

2、代码如下
#段文-----2022-3-10
'''
有时候很多的数据集合很大,而且不合适把他们手工整合,所以就需要你通过txt文件去构建DataLoader
这段代码就是去构建自己的txt文件,包括train.txt和val.txt
'''

import os
import random #乱序使用随机函数
#保存地址
first_filedir='/tmp/pycharm_project_825/data/crop_0506'
#这里要注意,traindir和evaldir不能放到first_filedir里面,
traindir='/tmp/pycharm_project_825/data/crop_0506_train.txt'
evaldir='/tmp/pycharm_project_825/data/crop_0506_eval.txt'
#判断该地址下,是否含有已经存在的txt
if(os.path.exists(traindir)):  # 判断有误文件
    os.remove(traindir)  # 删除文件
if(os.path.exists(evaldir)):  # 判断有误文件
    os.remove(evaldir)  # 删除文件
fp = open(traindir,'w')
fp = open(evaldir,'w')
#构建生成txt文件函数,传入参数为(文件夹地址,训练txt地址,测试txt地址)
def make_txt(first_filedir,traindir,evaldir):
    #构建训练列表
    train_list=[]
    eval_list=[]
    filelist = os.listdir(first_filedir) #获取主目录data/crop_0506下的文件目录

    for first_dir in filelist:
        second_filedir=first_filedir+'/'+first_dir# 遍历data/crop_0506里面的文件夹
        second_filelist=os.listdir(second_filedir)# 读取data/crop_0506里面文件夹(如06031140)中的文件
        count = 0#计数
        for second_dir in second_filelist:
            count = count + 1
            third_filedir=second_filedir+'/'+second_dir#图片的真实路径
            if count % 10 == 0:#定义为每10张图片,取一张为测试图片
                #此处可以自定义*
                eval_list.append(third_filedir + ' ' + first_dir[3:4] + '\n')
            else:
                train_list.append(third_filedir + ' ' + first_dir[3:4] + '\n')
    #乱序
    random.shuffle(eval_list)
    #读写
    with open(evaldir, 'a') as f:
        for eval_image in eval_list:
            f.write(eval_image)
    #乱序,将顺序打乱,这样效果更好
    random.shuffle(train_list)
    with open(traindir, 'a') as f:
        for train_image in train_list:
            f.write(train_image)
#执行函数
make_txt(first_filedir,traindir,evaldir)
print('Over')
3、这样就构建好了对应的txt文件,内容如下

通过txt文件来构建自己的Dataset和DataLoader_第4张图片
打开txt文件可以看到如下结果:包含着图片路径和对应的标签
通过txt文件来构建自己的Dataset和DataLoader_第5张图片

2- 构建自己的Dataset

1-构建自己的dataset类
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """自定义数据集"""
    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

可以从上面代码看出来,我们需要传入图片路径和图片标签,那么我们就通过我们之前做的txt文件来提取图片路径和图片标签

2、构建图片的路径和图片标签的列表
def creat_path(txt):
    with open(txt, 'r') as fh:
        imgs_path = []
        imgs_label=[]
        for line in fh:
            line = line.strip('\n')  # 移除字符串首尾的换行符
            line = line.rstrip()  # 删除末尾空
            words = line.split()  # 以空格为分隔符 将字符串分成
            imgs_path.append(words[0])
            imgs_label.append(int(words[1]) if int(words[1])==1 else int(words[1])-3)
            # print(imgs_path)
    return imgs_path,imgs_label
3、调用以上的函数和类
train_imgs_path,train_imgs_label=creat_path(root + '/' + 'crop_0506_train.txt')
val_imgs_path,val_imgs_label=creat_path(root + '/' + 'crop_0506_eval.txt')

这里的transform函数也可以自己去书写

#实例化训练数据集
train_dataset = MyDataSet(images_path=train_imgs_path,
                          images_class=train_imgs_label,
                          transform=data_transform["train"])
# 实例化验证数据集
val_dataset = MyDataSet(images_path=val_imgs_path,
                        images_class=val_imgs_label,
                        transform=data_transform["val"])
4、之后就可以构建DataLoader了
train_loader = data.DataLoader(train_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=nw,
                               )

val_loader = data.DataLoader(val_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=nw,
                             )
5、之后就可以在模型里面训练了

3-代码合并

有时候很多的数据集合很大,而且不合适把他们手工整合,所以就需要你通过txt文件去构建DataLoader
这段代码就是去构建自己的txt文件,包括train.txt和val.txt
'''

import os
import random #乱序使用随机函数
#保存地址
first_filedir='/tmp/pycharm_project_825/data/crop_0506'
#这里要注意,traindir和evaldir不能放到first_filedir里面,
traindir='/tmp/pycharm_project_825/data/crop_0506_train.txt'
evaldir='/tmp/pycharm_project_825/data/crop_0506_eval.txt'
#判断该地址下,是否含有已经存在的txt
if(os.path.exists(traindir)):  # 判断有误文件
    os.remove(traindir)  # 删除文件
if(os.path.exists(evaldir)):  # 判断有误文件
    os.remove(evaldir)  # 删除文件
fp = open(traindir,'w')
fp = open(evaldir,'w')
#构建生成txt文件函数,传入参数为(文件夹地址,训练txt地址,测试txt地址)
def make_txt(first_filedir,traindir,evaldir):
    #构建训练列表
    train_list=[]
    eval_list=[]
    filelist = os.listdir(first_filedir) #获取主目录data/crop_0506下的文件目录

    for first_dir in filelist:
        second_filedir=first_filedir+'/'+first_dir# 遍历data/crop_0506里面的文件夹
        second_filelist=os.listdir(second_filedir)# 读取data/crop_0506里面文件夹(如06031140)中的文件
        count = 0#计数
        for second_dir in second_filelist:
            count = count + 1
            third_filedir=second_filedir+'/'+second_dir#图片的真实路径
            if count % 10 == 0:#定义为每10张图片,取一张为测试图片
                #此处可以自定义*
                eval_list.append(third_filedir + ' ' + first_dir[3:4] + '\n')
            else:
                train_list.append(third_filedir + ' ' + first_dir[3:4] + '\n')
    #乱序
    random.shuffle(eval_list)
    #读写
    with open(evaldir, 'a') as f:
        for eval_image in eval_list:
            f.write(eval_image)
    #乱序,将顺序打乱,这样效果更好
    random.shuffle(train_list)
    with open(traindir, 'a') as f:
        for train_image in train_list:
            f.write(train_image)
#执行函数
make_txt(first_filedir,traindir,evaldir)
print('Over')
#2
root='/tmp/pycharm_project_825/data'
def creat_path(txt):
    with open(txt, 'r') as fh:
        imgs_path = []
        imgs_label=[]
        for line in fh:
            line = line.strip('\n')  # 移除字符串首尾的换行符
            line = line.rstrip()  # 删除末尾空
            words = line.split()  # 以空格为分隔符 将字符串分成
            imgs_path.append(words[0])
            imgs_label.append(int(words[1]) if int(words[1])==1 else int(words[1])-3)
            # print(imgs_path)
    return imgs_path,imgs_label
train_imgs_path,train_imgs_label=creat_path(root + '/' + 'crop_0506_train.txt')
val_imgs_path,val_imgs_label=creat_path(root + '/' + 'crop_0506_eval.txt')
#3
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
#4
data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
# 实例化训练数据集
train_dataset = MyDataSet(images_path=train_imgs_path,
                              images_class=train_imgs_label,
                              transform=data_transform["train"])

# 实例化验证数据集
val_dataset = MyDataSet(images_path=val_imgs_path,
                            images_class=val_imgs_label,
                            transform=data_transform["val"])
                           
# 5
train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   pin_memory=True,
                                   num_workers=nw,
                                   collate_fn=train_dataset.collate_fn)
val_loader = data.DataLoader(val_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=nw,
                                 collate_fn=val_dataset.collate_fn)

你可能感兴趣的:(深度学习,机器学习,pycharm,深度学习,python)