#段文-----2022-3-10
'''
有时候很多的数据集合很大,而且不合适把他们手工整合,所以就需要你通过txt文件去构建DataLoader
这段代码就是去构建自己的txt文件,包括train.txt和val.txt
'''
import os
import random #乱序使用随机函数
# Output locations for the generated annotation files.
first_filedir='/tmp/pycharm_project_825/data/crop_0506'
# NOTE: traindir/evaldir must NOT live inside first_filedir, otherwise
# os.listdir in make_txt would pick the txt files up as class folders.
traindir='/tmp/pycharm_project_825/data/crop_0506_train.txt'
evaldir='/tmp/pycharm_project_825/data/crop_0506_eval.txt'
# Remove stale txt files from a previous run, then recreate them empty.
# The original code did `fp = open(traindir,'w'); fp = open(evaldir,'w')`,
# leaking the first file handle; `with` closes each handle immediately.
if os.path.exists(traindir):
    os.remove(traindir)
if os.path.exists(evaldir):
    os.remove(evaldir)
with open(traindir, 'w'):
    pass
with open(evaldir, 'w'):
    pass
def make_txt(first_filedir,traindir,evaldir):
    """Scan first_filedir/<class_dir>/<image> and append shuffled
    "path label" lines to traindir (train split) and evaldir (eval split).

    Every 10th image of each sub-directory goes to the eval split; the
    label is the 4th character of the sub-directory's name.
    """
    train_lines = []
    eval_lines = []
    for sub_name in os.listdir(first_filedir):
        sub_path = first_filedir + '/' + sub_name
        # The 4th character of the folder name (e.g. "06031140") is the label.
        label = sub_name[3:4]
        for idx, image_name in enumerate(os.listdir(sub_path), start=1):
            record = sub_path + '/' + image_name + ' ' + label + '\n'
            # One out of every 10 images becomes an evaluation sample.
            if idx % 10 == 0:
                eval_lines.append(record)
            else:
                train_lines.append(record)
    # Shuffle so classes are interleaved in the output files.
    random.shuffle(eval_lines)
    with open(evaldir, 'a') as out:
        out.writelines(eval_lines)
    random.shuffle(train_lines)
    with open(traindir, 'a') as out:
        out.writelines(train_lines)
# Build the train/eval annotation files for the hard-coded paths above.
make_txt(first_filedir,traindir,evaldir)
print('Over')
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Dataset backed by explicit lists of image paths and class labels."""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        path = self.images_path[item]
        img = Image.open(path)
        # Only colour images are accepted; grey-scale ('L') images are rejected.
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(path))
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        # For the official default_collate implementation, see:
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))
        return torch.stack(images, dim=0), torch.as_tensor(labels)
# As the code above shows, MyDataSet needs image paths and labels; we extract
# both from the txt files built earlier.
def creat_path(txt):
    """Parse an annotation txt file into parallel path/label lists.

    Each non-empty line has the form "<image_path> <label>".  Labels are
    remapped: label 1 stays 1, any other label has 3 subtracted from it
    (e.g. 4 -> 1, 5 -> 2).

    Returns:
        (imgs_path, imgs_label): list of path strings, list of ints.
    """
    imgs_path = []
    imgs_label = []
    with open(txt, 'r') as fh:
        for line in fh:
            line = line.strip()
            if not line:
                # Skip blank lines instead of crashing on words[1].
                continue
            words = line.split()  # split on whitespace: [path, label]
            imgs_path.append(words[0])
            raw = int(words[1])  # parse once instead of three times
            imgs_label.append(raw if raw == 1 else raw - 3)
    return imgs_path, imgs_label
# NOTE(review): `root` is not defined at this point in the file (it is only
# assigned much further down) — running this as-is raises NameError; confirm
# the intended ordering against the full script.
train_imgs_path,train_imgs_label=creat_path(root + '/' + 'crop_0506_train.txt')
val_imgs_path,val_imgs_label=creat_path(root + '/' + 'crop_0506_eval.txt')
# The transform pipeline (data_transform) used below can also be customized.
# Instantiate the training dataset.
# NOTE(review): `data_transform`, `data` (presumably torch.utils.data),
# `batch_size` and `nw` are not defined in this part of the file — confirm
# they come from the surrounding script before running.
train_dataset = MyDataSet(images_path=train_imgs_path,
                          images_class=train_imgs_label,
                          transform=data_transform["train"])
# Instantiate the validation dataset.
val_dataset = MyDataSet(images_path=val_imgs_path,
                        images_class=val_imgs_label,
                        transform=data_transform["val"])
train_loader = data.DataLoader(train_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=nw,
                               )
val_loader = data.DataLoader(val_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=nw,
                             )
'''
Sometimes a data set is too large to merge by hand, so we build the
DataLoader from txt files instead.  The code below (a cleaned-up repeat of
the snippet above) builds the train.txt and eval.txt files.
'''
import os
import random #乱序使用随机函数
# Output locations for the generated annotation files.
first_filedir='/tmp/pycharm_project_825/data/crop_0506'
# NOTE: traindir/evaldir must NOT live inside first_filedir, otherwise
# os.listdir in make_txt would pick the txt files up as class folders.
traindir='/tmp/pycharm_project_825/data/crop_0506_train.txt'
evaldir='/tmp/pycharm_project_825/data/crop_0506_eval.txt'
# Remove stale txt files from a previous run, then recreate them empty.
# The original code did `fp = open(traindir,'w'); fp = open(evaldir,'w')`,
# leaking the first file handle; `with` closes each handle immediately.
if os.path.exists(traindir):
    os.remove(traindir)
if os.path.exists(evaldir):
    os.remove(evaldir)
with open(traindir, 'w'):
    pass
with open(evaldir, 'w'):
    pass
def make_txt(first_filedir,traindir,evaldir):
    """Scan first_filedir/<class_dir>/<image> and append shuffled
    "path label" lines to traindir (train split) and evaldir (eval split).

    Every 10th image of each sub-directory goes to the eval split; the
    label is the 4th character of the sub-directory's name.
    """
    train_lines = []
    eval_lines = []
    for sub_name in os.listdir(first_filedir):
        sub_path = first_filedir + '/' + sub_name
        # The 4th character of the folder name (e.g. "06031140") is the label.
        label = sub_name[3:4]
        for idx, image_name in enumerate(os.listdir(sub_path), start=1):
            record = sub_path + '/' + image_name + ' ' + label + '\n'
            # One out of every 10 images becomes an evaluation sample.
            if idx % 10 == 0:
                eval_lines.append(record)
            else:
                train_lines.append(record)
    # Shuffle so classes are interleaved in the output files.
    random.shuffle(eval_lines)
    with open(evaldir, 'a') as out:
        out.writelines(eval_lines)
    random.shuffle(train_lines)
    with open(traindir, 'a') as out:
        out.writelines(train_lines)
# Build the train/eval annotation files for the hard-coded paths above.
make_txt(first_filedir,traindir,evaldir)
print('Over')
# Step 2: read the generated txt files back into path/label lists.
root='/tmp/pycharm_project_825/data'
def creat_path(txt):
    """Parse an annotation txt file into parallel path/label lists.

    Each non-empty line has the form "<image_path> <label>".  Labels are
    remapped: label 1 stays 1, any other label has 3 subtracted from it
    (e.g. 4 -> 1, 5 -> 2).

    Returns:
        (imgs_path, imgs_label): list of path strings, list of ints.
    """
    imgs_path = []
    imgs_label = []
    with open(txt, 'r') as fh:
        for line in fh:
            line = line.strip()
            if not line:
                # Skip blank lines instead of crashing on words[1].
                continue
            words = line.split()  # split on whitespace: [path, label]
            imgs_path.append(words[0])
            raw = int(words[1])  # parse once instead of three times
            imgs_label.append(raw if raw == 1 else raw - 3)
    return imgs_path, imgs_label
# Load the train and eval splits produced by make_txt above.
train_imgs_path,train_imgs_label=creat_path(root + '/' + 'crop_0506_train.txt')
val_imgs_path,val_imgs_label=creat_path(root + '/' + 'crop_0506_eval.txt')
#3
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Dataset backed by explicit lists of image paths and class labels."""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        path = self.images_path[item]
        img = Image.open(path)
        # Only colour images are accepted; grey-scale ('L') images are rejected.
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(path))
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        # For the official default_collate implementation, see:
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))
        return torch.stack(images, dim=0), torch.as_tensor(labels)
# Step 4: preprocessing pipelines — random crop + flip for training,
# resize + center crop for validation, both normalized to [-1, 1].
# NOTE(review): `transforms` is never imported in this file — this requires
# `from torchvision import transforms` to run; confirm against the full script.
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
# Instantiate the training dataset.
train_dataset = MyDataSet(images_path=train_imgs_path,
                          images_class=train_imgs_label,
                          transform=data_transform["train"])
# Instantiate the validation dataset.
val_dataset = MyDataSet(images_path=val_imgs_path,
                        images_class=val_imgs_label,
                        transform=data_transform["val"])
# Step 5: build the DataLoaders using the custom collate_fn.
# NOTE(review): `data`, `batch_size` and `nw` are not defined anywhere in
# this file — presumably `from torch.utils import data` plus script-level
# config values; confirm before running.
train_loader = data.DataLoader(train_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=nw,
                               collate_fn=train_dataset.collate_fn)
val_loader = data.DataLoader(val_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=nw,
                             collate_fn=val_dataset.collate_fn)