Custom PyTorch Dataset with torchvision Transforms

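I've recently started using PyTorch, and it is far more pleasant to use than TensorFlow, especially when it comes to reading data. TensorFlow's long, convoluted input APIs are off-putting, and because TensorFlow does not switch seamlessly to and from numpy, it is hard to use ready-made tools for structured data such as pandas, which causes a lot of unnecessary trouble.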

For how to write custom data loading and Transforms in PyTorch, see the documentation:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

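However, a custom Dataset written exactly as that tutorial describes can only use the tutorial's own hand-written Transform classes: each sample is a dict holding a numpy image, while torchvision's image transforms take a PIL Image. torchvision already ships implementations of many image transforms, so to reuse them we can define the Dataset as follows: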

class FaceLandmarkDataset(Dataset):
    def __init__(self, csv_file: str, root_dir: str, transform=None) -> None:
        super().__init__()
        # Column 0: image file name; remaining columns: x, y coordinate pairs
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self) -> int:
        return len(self.landmarks_frame)

    def __getitem__(self, index: int):
        img_name = self.landmarks_frame.iloc[index, 0]
        img_path = os.path.join(self.root_dir, img_name)
        # Read with PIL and force 3 channels so torchvision transforms receive an RGB PIL Image
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        landmarks = self.landmarks_frame.iloc[index, 1:].to_numpy().astype('float')
        landmarks = landmarks.reshape(-1, 2)  # shape: (num_landmarks, 2)
        if self.transform is not None:
            # Only the image is transformed; landmarks are NOT adjusted for crops or resizes
            image = self.transform(image)
        return image, landmarks

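The key part is the __getitem__ method, which reads the image file from disk. Be sure to read it with PIL's Image module and then convert it to RGB: classic torchvision transforms such as RandomResizedCrop take a PIL Image, and convert('RGB') guarantees three channels even for grayscale or RGBA files, so every image in a batch ends up with the same shape.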

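Also, the landmarks read as a numpy ndarray do not need to be converted to a PyTorch Tensor explicitly; the DataLoader's default collate function converts and stacks them for you (converting them yourself makes no real difference either, it is just a torch.from_numpy call).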
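A quick way to confirm this: the hypothetical NumpyLandmarks dataset below returns plain float64 ndarrays, and the batch that comes out of the DataLoader is already a torch.Tensor. Note that the dtype is preserved, so float64 landmarks arrive as torch.float64; casting with astype('float32') in __getitem__ may be preferable if the model works in float32.

```python
import numpy as np
from torch.utils.data import Dataset, DataLoader

class NumpyLandmarks(Dataset):
    """Hypothetical stand-in: each sample is a (68, 2) float64 ndarray."""

    def __init__(self, n: int) -> None:
        self.data = np.arange(n * 68 * 2, dtype=np.float64).reshape(n, 68, 2)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> np.ndarray:
        return self.data[index]  # plain ndarray, no torch.from_numpy here

# The default collate function turns the list of ndarrays into one stacked tensor
batch = next(iter(DataLoader(NumpyLandmarks(8), batch_size=4)))
print(type(batch).__name__, tuple(batch.shape), batch.dtype)
# Tensor (4, 68, 2) torch.float64
```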

The complete code is as follows:


import os
import torch
import pandas as pd
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class FaceLandmarkDataset(Dataset):
    def __init__(self, csv_file: str, root_dir: str, transform=None) -> None:
        super().__init__()
        # Column 0: image file name; remaining columns: x, y coordinate pairs
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self) -> int:
        return len(self.landmarks_frame)

    def __getitem__(self, index: int):
        img_name = self.landmarks_frame.iloc[index, 0]
        img_path = os.path.join(self.root_dir, img_name)
        # Read with PIL and force 3 channels so torchvision transforms receive an RGB PIL Image
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        landmarks = self.landmarks_frame.iloc[index, 1:].to_numpy().astype('float')
        landmarks = landmarks.reshape(-1, 2)  # shape: (num_landmarks, 2)
        if self.transform is not None:
            # Only the image is transformed; landmarks are NOT adjusted for crops or resizes
            image = self.transform(image)
        return image, landmarks



trans = transforms.Compose([
    transforms.RandomResizedCrop(size=128),  # named RandomSizedCrop in older torchvision releases
    transforms.ToTensor()
])

face_dataset = FaceLandmarkDataset(csv_file='faces/face_landmarks.csv', root_dir='faces', transform=trans)
loader = DataLoader(dataset=face_dataset, batch_size=4, shuffle=True, num_workers=4)

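The resulting DataLoader is an iterable, so you can loop over it directly with a for loop; each iteration yields one (image batch, landmark batch) pair.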
