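I've recently started using PyTorch, and it is far more pleasant to work with than TensorFlow, especially when it comes to reading data. TensorFlow's long, convoluted data APIs are off-putting, and because TensorFlow does not interoperate seamlessly with numpy, it is hard to use ready-made tools for structured data such as pandas, which causes a lot of unnecessary trouble.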
For how to write custom data loading and Transform steps in PyTorch, see the official tutorial:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
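However, a custom Dataset written exactly as described in that tutorial can only be combined with the tutorial's own custom Transform classes, because it returns each sample as a dict holding a numpy image, which torchvision's image transforms cannot take as input. The torchvision package already implements many common image transforms, and to reuse them we can define the Dataset as follows: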
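```python
import os

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class FaceLandmarkDataset(Dataset):
    def __init__(self, csv_file: str, root_dir: str, transform=None) -> None:
        super().__init__()
        # First CSV column: image file name; remaining columns: x, y pairs
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self) -> int:
        return len(self.landmarks_frame)

    def __getitem__(self, index: int):
        # .iloc replaces the .ix indexer, which was removed in pandas 1.0
        img_name = self.landmarks_frame.iloc[index, 0]
        # Use root_dir instead of a hard-coded './faces'
        img_path = os.path.join(self.root_dir, img_name)
        # Read with PIL and force three channels for the torchvision transforms
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # .to_numpy() replaces the removed .as_matrix(); shape becomes (num_points, 2)
        landmarks = self.landmarks_frame.iloc[index, 1:].to_numpy().astype('float32')
        landmarks = landmarks.reshape(-1, 2)
        if self.transform is not None:
            image = self.transform(image)
        return image, landmarks
```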
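The key part of this code is the __getitem__ function, which reads the image file from disk. Note in particular that the image should be read with PIL's Image module and then converted to RGB: the torchvision transforms used here operate on PIL Images, and the conversion guarantees that every image has exactly three channels, so grayscale or RGBA files do not end up as tensors with 1 or 4 channels.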
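Also, the landmarks read as a numpy ndarray do not need to be converted to a PyTorch Tensor explicitly: when the DataLoader assembles a batch, its default collate function converts numpy arrays to tensors for you. (Converting them yourself does no harm either; it is just a call to torch.from_numpy.)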
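A minimal sketch of that automatic conversion, assuming a hypothetical toy dataset that is just a Python list of (image, landmarks) pairs stored as numpy arrays. With the default collate function, each field of the batch comes back as a stacked torch.Tensor with the numpy dtype preserved.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

# Hypothetical toy dataset: a plain list of (image, landmarks) numpy pairs
samples = [
    (np.zeros((3, 8, 8), dtype=np.float32), np.ones((5, 2), dtype=np.float32) * i)
    for i in range(6)
]

loader = DataLoader(samples, batch_size=4, shuffle=False)

# The default collate function stacks each field and converts it to a tensor
images, landmarks = next(iter(loader))
print(type(images), images.shape)        # <class 'torch.Tensor'> torch.Size([4, 3, 8, 8])
print(type(landmarks), landmarks.shape)  # <class 'torch.Tensor'> torch.Size([4, 5, 2])

# Explicit conversion also works; torch.from_numpy shares memory with the array
t = torch.from_numpy(samples[0][1])
print(t.dtype)  # torch.float32
```

The same applies to FaceLandmarkDataset above: its landmarks arrive in each batch as a tensor of shape (batch_size, num_points, 2).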
The complete code is as follows:
```python
import os

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class FaceLandmarkDataset(Dataset):
    def __init__(self, csv_file: str, root_dir: str, transform=None) -> None:
        super().__init__()
        # First CSV column: image file name; remaining columns: x, y pairs
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self) -> int:
        return len(self.landmarks_frame)

    def __getitem__(self, index: int):
        # .iloc replaces the .ix indexer, which was removed in pandas 1.0
        img_name = self.landmarks_frame.iloc[index, 0]
        img_path = os.path.join(self.root_dir, img_name)
        # Read with PIL and force three channels for the torchvision transforms
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # .to_numpy() replaces the removed .as_matrix(); shape becomes (num_points, 2)
        landmarks = self.landmarks_frame.iloc[index, 1:].to_numpy().astype('float32')
        landmarks = landmarks.reshape(-1, 2)
        if self.transform is not None:
            image = self.transform(image)
        return image, landmarks


# RandomResizedCrop replaces the deprecated RandomSizedCrop.
# Note: it crops and resizes the image only; the landmarks are not adjusted.
trans = transforms.Compose([
    transforms.RandomResizedCrop(size=128),
    transforms.ToTensor(),
])

if __name__ == '__main__':
    # The main guard is required for num_workers > 0 on platforms that spawn workers (Windows, macOS)
    face_dataset = FaceLandmarkDataset(csv_file='faces/face_landmarks.csv',
                                       root_dir='faces', transform=trans)
    loader = DataLoader(dataset=face_dataset, batch_size=4, shuffle=True, num_workers=4)
```
The resulting DataLoader is an iterable (not itself an iterator), so you can loop over it directly; each iteration yields one batch.