Pytorch_Tutorial自定义图像数据类以及预处理

使用 PyTorch 训练自己的模型，第一件事也是最重要的事，就是把训练集按 PyTorch 的格式导入并完成预处理。完整教程代码如下，主要内容包括：导入人脸关键点训练集、自定义数据集类、自定义图像预处理（变换）类。主要参考英文教程 https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
# NOTE(review): this hides *all* warnings, including pandas/NumPy deprecation
# notices that would have flagged removed APIs such as Series.as_matrix().
import warnings
warnings.filterwarnings('ignore')

# Turn on matplotlib interactive mode.
plt.ion()

landmarks_frame = pd.read_csv('./data/faces/faces/face_landmarks.csv')

# Inspect one annotation row: column 0 is the image file name, the remaining
# columns are flattened landmark coordinates, regrouped here into (x, y) rows.
n = 65
img_name = landmarks_frame.iloc[n, 0]
# Series.as_matrix() was removed in pandas 1.0; to_numpy() is its replacement.
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
landmarks = landmarks.astype('float').reshape(-1, 2)
print('Image name:{}'.format(img_name))
print('Landmarks shape:{}'.format(landmarks.shape))
print('First 4 Landmarks:{}'.format(landmarks[:4]))

def show_landmarks(image, landmarks):
    """Display ``image`` with its landmark points overlaid as small red dots.

    Args:
        image: image array passed straight to ``plt.imshow``.
        landmarks: 2-D array of points; column 0 is plotted as x and
            column 1 as y.
    """
    plt.imshow(image)
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    # Short pause so matplotlib gets a chance to redraw the figure.
    plt.pause(0.001)
    
# Open a new figure and draw the sample inspected above with its landmarks.
plt.figure()
show_landmarks(
    image=io.imread(os.path.join('data/faces/faces/', img_name)),
    landmarks=landmarks,
)
plt.show()

class Rescale(object):
    """Resize the image of a sample to a requested size.

    Args:
        output_size (tuple or int): Target size. A tuple ``(h, w)`` is used
            directly. An int becomes the length of the *shorter* image edge;
            the longer edge is scaled so the aspect ratio is preserved.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, points = sample['image'], sample['landmarks']
        old_h, old_w = image.shape[:2]

        target = self.output_size
        if not isinstance(target, int):
            target_h, target_w = target
        elif old_h > old_w:
            target_h, target_w = target * old_h / old_w, target
        else:
            target_h, target_w = target, target * old_w / old_h
        target_h, target_w = int(target_h), int(target_w)

        resized = transform.resize(image, (target_h, target_w))

        # Landmarks are (x, y) = (column, row): x scales with the width
        # ratio and y with the height ratio, i.e. the reverse of shape order.
        factors = [target_w / old_w, target_h / old_h]
        return {'image': resized, 'landmarks': points * factors}
    
class RandomCrop(object):
    """Randomly crop the image in a sample.

    Args:
        output_size (tuple or int): Desired output size ``(h, w)``. If int,
            a square crop is made.

    Raises (from ``__call__``):
        ValueError: if the crop is larger than the image in either dimension.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        # An int means a square crop.
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size {} is larger than image size {}'.format(
                    (new_h, new_w), (h, w)))

        # np.random.randint excludes its upper bound, so add 1: every valid
        # offset (including h - new_h / w - new_w) can be drawn, and a crop
        # equal to the image size yields offset 0 instead of raising.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top:top + new_h, left:left + new_w]

        # Landmarks are (x, y): shift them into the crop's coordinate frame.
        landmarks = landmarks - [left, top]
        return {'image': image, 'landmarks': landmarks}
    

    
class ToTensor(object):
    """Convert the ndarrays in a sample to Tensors.

    The image goes from numpy's H x W x C layout to torch's C x H x W layout.
    A 2-D (H x W, e.g. grayscale) image first gets a trailing channel axis,
    so it becomes a 1 x H x W tensor instead of failing in ``transpose``.
    """

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
        # swap color axis
        # numpy image: H * W * C
        # torch image: C * H * W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
    
    

class FaceLandmarksDataset(Dataset):
    """Face landmarks dataset, implemented as a ``torch.utils.data.Dataset``.

    Each CSV row holds an image file name (column 0) followed by flattened
    landmark coordinates. Only the CSV is read up front; images are read
    lazily, one per ``__getitem__`` call.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images (data rows in the CSV)."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'image': ..., 'landmarks': (N, 2) floats}``.

        If a transform was given, it is applied to that dict and its result
        is returned instead.
        """
        if torch.is_tensor(idx):
            # NOTE(review): a 0-d tensor becomes an int here, but a 1-d tensor
            # becomes a list, which the single-sample code below cannot handle.
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Flattened coordinate columns -> float array of (x, y) rows.
        landmarks = (self.landmarks_frame.iloc[idx, 1:]
                     .to_numpy(dtype=float)
                     .reshape(-1, 2))
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample
    
face_dataset = FaceLandmarksDataset(csv_file='data/faces/faces/face_landmarks.csv',
                                    root_dir='data/faces/faces/')

# Plot (up to) the first four samples side by side, printing their shapes.
fig = plt.figure()

for i in range(min(4, len(face_dataset))):
    sample = face_dataset[i]
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

# Show once after the loop (instead of on the 4th iteration) so the figure is
# still displayed when the dataset has fewer than four samples.
plt.show()

        
# Individual transforms and their composition, to compare on one sample.
scale, crop = Rescale(256), RandomCrop(128)
composed = transforms.Compose([Rescale(256), RandomCrop(224)])

# Apply each transform to the same raw sample and plot the results in a row.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)
    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(transformed_sample['image'], transformed_sample['landmarks'])

plt.show()

在 Jupyter Lab 中执行上述代码，得到如下结果：
Pytorch_Tutorial自定义图像数据类以及预处理_第1张图片

代码分析和需要注意的知识点:

1. 路径写法：`/` 为根目录，`./` 为当前目录，`../` 为上一级目录。这是操作系统的路径约定，并非 Python 特有；相对路径是相对于当前工作目录（而不是脚本文件所在目录）来解析的。

import os
import pandas as pd
# data_dir, filename and n are placeholders: your data directory, CSV file name and row index.
landmarks_frame = pd.read_csv(os.path.join(data_dir, filename))
# as_matrix() was removed in pandas 1.0; to_numpy() is the replacement.
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()

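csv 即 Comma-Separated Values（逗号分隔值）文件：记录之间换行，字段之间用单个分隔字符（一般是逗号）隔开。`iloc` 是基于整数位置（而非标签；基于标签的是 `loc`）的行列切片。查看 csv 可知首行是表头（header），之后每一行是一条记录，包含图片名和各标记点的 (x, y) 坐标，因此 `iloc[n, 0]` 为第 n 条记录的图片名，`iloc[n, 1:]` 为该图片的全部坐标值。注意：`as_matrix()` 已在 pandas 1.0 中移除，新版本应改用 `to_numpy()`。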

  2. 继承 torch 的 Dataset 类，重写 `__init__`、`__len__` 和 `__getitem__`：
class myDataset(Dataset):
    """Skeleton of a custom dataset: subclass Dataset and override three methods."""

    def __init__(self):
        # Initialise the object here, e.g. read the annotation (CSV) file.
        # Loading the images themselves is deferred to __getitem__.
        pass

    def __len__(self):
        # Return the number of samples in the dataset.
        raise NotImplementedError

    def __getitem__(self, idx):
        # Load and return sample `idx`; apply any transform here.
        raise NotImplementedError

在构造函数 `__init__` 中完成对象初始化（读取 CSV 标注文件），用 `__len__` 返回数据集大小，用 `__getitem__` 按索引读入单张图片并构造样本（预处理 transform 也在这里完成）。注意：这里只在 `__init__` 中读取 CSV，图像则在 `__getitem__` 中按需读取，而不是一次性把所有图像载入内存，因此更节省内存，也适用于大规模数据集。

  3. 将预处理操作定义为可调用类 Rescale、RandomCrop、ToTensor（实现 `__call__` 方法）。这样变换所需的参数（如输出尺寸）只需在构造对象时传入一次，之后对每个样本直接调用该对象即可，无需每次重复传参；同时这些对象可以用 `transforms.Compose` 组合成预处理流水线：
# Rescale so the shorter image edge becomes 256 (aspect ratio kept), then take a random 224x224 crop.
composed = transforms.Compose([Rescale(256), RandomCrop(224)]) 

`transforms.Compose` 会按列表顺序依次应用各个变换（先 Rescale 再 RandomCrop），如此完成预处理的组合处理。

你可能感兴趣的:(深度学习,Pytorch)