pytorch数据集构建datasets

pytorch内置了Dataset类,可以通过自己的任意数据对训练、测试数据集进行构建。
在使用时需要自己定义类,需要拥有__len__、 __ getitem__类,用于获取数据集的长度,以及对数据集内容进行迭代。

以下即是一个自主构建数据集类的例子。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

__ init__:初始化函数用于对数据集的某些参数进行定义。
__ len__:用于返回数据集的长度。
__ getitem__:用于对数据集进行迭代。

以下为调用方式:

face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                    root_dir='faces/')

你可能感兴趣的:(pytorch数据集构建datasets)