当我们使用 PyTorch 进行深度学习模型开发时,为了使用自己的数据集训练网络,往往需要构建自己的 DataLoader。具体做法是自定义一个继承 `torch.utils.data.Dataset` 的数据集类,再交给 DataLoader 封装成可迭代的数据装载器,之后即可通过 `for img, labels in dataloader:` 遍历该可迭代对象,从而提取数据用于训练与验证。
PyTorch 通过循环遍历这个 DataLoader 对象,将 img 和 label 加载到模型中进行训练。
自定义的数据集类(本文中为 `My_data`)中以下三个方法缺一不可:
1. `__init__()`:主要用来完成数据的初始化与预处理(本文中为收集各文件夹的图片路径);
2. `__getitem__()`:返回一条数据的 img 和 label;
3. `__len__()`:返回数据的个数。
PyTorch 会通过本文自定义的数据集类 `My_data`,从 ch1、ch2、ch3 三个文件夹中按一定比例(batchSize 为 6 时,每个 batch 中三个文件夹各取 2 个样本,即 2:2:2)提取 img 和 mask 数据作为模型输入,从而实现多通道按比例输入不同 label 的数据。
import os
import numpy as np
import random
import cv2
from torch.utils import data
import glob
class My_data(data.Dataset):
    """Dataset that mixes image/mask pairs from three channel folders by ratio.

    Expected layout under ``opt.data_root``::

        ch1/img/*.png  ch1/mask/*.png
        ch2/img/*.png  ch2/mask/*.png
        ch3/img/*.png  ch3/mask/*.png

    Images and masks of a channel are paired by their sorted path order.
    Within every batch of ``opt.batchSize`` samples, consecutive positions are
    assigned to the channels in runs of two: positions 0-1 -> ch1,
    2-3 -> ch2, 4-5 -> ch3, 6-7 -> ch1 again, and so on. A batch size of 6
    therefore yields a 2:2:2 mix.

    NOTE: the channel is picked from an internal call counter; the ``index``
    passed to ``__getitem__`` is ignored. The per-batch mix only holds when
    samples are requested one after another in a single process (e.g.
    ``DataLoader(dataset, batch_size=opt.batchSize, num_workers=0)``). With
    ``num_workers > 0`` every worker process works on its own copy of the
    dataset, so the counters are not shared between workers.
    """

    # Sub-folders of ``opt.data_root`` that samples are drawn from.
    _CHANNELS = ('ch1', 'ch2', 'ch3')
    # Number of consecutive batch positions served by one channel.
    _RUN_LENGTH = 2

    def __init__(self, opt):
        """Store the options and index the data found under ``opt.data_root``.

        Args:
            opt: options object; this class reads ``opt.data_root``,
                ``opt.batchSize`` and ``opt.niter``.
        """
        self.opt = opt
        self.__counter = 0  # number of samples returned so far
        self.load_data()

    def load_data(self):
        """Collect and sort the img/mask paths of every channel.

        Also creates, per channel, a shuffled visiting order and a cursor
        into that order.
        """
        root_path = self.opt.data_root
        self._imgs = []     # per channel: sorted image paths
        self._masks = []    # per channel: sorted mask paths
        self._orders = []   # per channel: shuffled sample indices
        self._cursors = []  # per channel: position inside its shuffled order
        for channel in self._CHANNELS:
            imgs = glob.glob(os.path.join(root_path, channel, 'img', '*.png'))
            masks = glob.glob(os.path.join(root_path, channel, 'mask', '*.png'))
            self.check_error(imgs, masks, channel)
            self._imgs.append(imgs)
            self._masks.append(masks)
            self._orders.append(self.get_randon_index(len(imgs)))
            self._cursors.append(0)

    def check_error(self, img, mask, path):
        """Sort both path lists in place and warn when their lengths differ.

        Sorting is what pairs ``img[i]`` with ``mask[i]``; a length mismatch
        means that pairing is unreliable for this channel.

        Args:
            img: list of image paths (sorted in place).
            mask: list of mask paths (sorted in place).
            path: channel name, only used in the warning message.
        """
        img.sort()
        mask.sort()
        if len(img) != len(mask):
            print("Data Erro!", path)

    def get_data(self, index, img_path, mask_path):
        """Read the image and mask at position ``index`` of the two path lists.

        Both files are read with ``cv2.imread`` default flags (3-channel BGR).

        Raises:
            OSError: if OpenCV cannot read either file (``imread`` returns None).
        """
        img = cv2.imread(img_path[index])
        mask = cv2.imread(mask_path[index])
        if img is None or mask is None:
            raise OSError(
                f"could not read image/mask pair: {img_path[index]!r}, {mask_path[index]!r}"
            )
        return img, mask

    def get_randon_index(self, length):
        """Return a random permutation of ``range(length)`` as a list."""
        return random.sample(range(length), length)

    def image_swap(self, img):
        """Convert an HWC image array to a CHW float32 array."""
        return np.transpose(img, (2, 0, 1)).astype(np.float32)

    def mask_swap(self, mask):
        """Convert an HWC mask array to a CHW int64 array."""
        # np.int64 instead of the removed/platform-dependent np.long alias.
        return np.transpose(mask, (2, 0, 1)).astype(np.int64)

    @classmethod
    def _channel_for(cls, batch_counter):
        """Map a position inside the batch to a channel index (0, 1 or 2)."""
        return (batch_counter // cls._RUN_LENGTH) % len(cls._CHANNELS)

    def _next_sample_index(self, channel):
        """Return the next shuffled sample index of ``channel`` and advance it.

        The channel is reshuffled after every one of its samples was used once.

        Raises:
            RuntimeError: if the channel folder contains no images.
        """
        size = len(self._imgs[channel])
        if size == 0:
            raise RuntimeError(
                f"no images found for channel {self._CHANNELS[channel]!r}"
            )
        sample = self._orders[channel][self._cursors[channel]]
        self._cursors[channel] += 1
        if self._cursors[channel] >= size:  # epoch of this channel finished
            self._orders[channel] = self.get_randon_index(size)
            self._cursors[channel] = 0
        return sample

    def __getitem__(self, index):
        """Return the next ``(image, mask)`` pair for the current batch position.

        ``index`` is ignored; see the class docstring. The image is returned
        as CHW float32 scaled to [0, 1]; the mask as CHW values divided by
        255 (int64 / 255, i.e. float64).
        """
        self.batch_counter = int(self.__counter % self.opt.batchSize)
        channel = self._channel_for(self.batch_counter)
        sample = self._next_sample_index(channel)
        img1, mask = self.get_data(sample, self._imgs[channel], self._masks[channel])
        img1 = self.image_swap(img1) / 255
        mask = self.mask_swap(mask) / 255
        self.__counter += 1
        return img1, mask

    def __len__(self):
        """Total samples per run: ``opt.niter`` batches of ``opt.batchSize``."""
        return self.opt.niter * self.opt.batchSize