根据任务需求自制数据集:Pytorch自定义数据集Dataset代码示例(有监督学习,输入输出均为图像)

自定义数据集

  • 一、使用torchvision.io读取照片
  • 二、使用PIL读取照片

一、使用torchvision.io读取照片

import numpy as np
import torch
from PIL import Image
import numpy
from matplotlib import pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,utils
import warnings
import pandas as pd
import os
import sklearn
from skimage import io,transform
import yaml
import pathlib
from torchvision.io import image



# Silence library warnings and print full numpy arrays (no "..." truncation).
warnings.filterwarnings("ignore")
# BUG FIX: np.printoptions(...) is a context manager and has no effect when
# called bare (and np.inf was passed positionally as `precision`).  The
# persistent form is set_printoptions(threshold=np.inf).
np.set_printoptions(threshold=np.inf)

gpu_is_available = torch.cuda.is_available()
print("GPU is {}".format( "available" if gpu_is_available else "not available"))


def read_yaml_data():
    """Load the experiment configuration from ./environments.yaml.

    Returns:
        dict: parsed YAML content (expected to contain a 'data_path' section).
    """
    file_path = './environments.yaml'
    with open(file_path, 'r', encoding='utf-8') as f:
        # safe_load is preferred over load(Loader=FullLoader): it only
        # constructs plain Python objects, which is all a config file needs.
        data = yaml.safe_load(f)
    return data


def _sorted_img_paths(dir_path):
    """Return the sorted string paths of all entries directly under dir_path."""
    return sorted(str(p) for p in pathlib.Path(dir_path).glob('*'))


def read_imgs_paths():
    """Collect (hazy, ground-truth) image path lists for train and val splits.

    Directory locations come from the 'data_path' section of environments.yaml.
    Each list is sorted so hazy/gt pairs line up by index (assumes matching
    filenames sort identically in both directories -- TODO confirm).

    Returns:
        ((train_hazy, train_gt), (val_hazy, val_gt)): four lists of str paths.
    """
    data_paths = read_yaml_data()['data_path']
    train_hazy_paths = _sorted_img_paths(data_paths['train_hazy_dir'])
    train_gt_paths = _sorted_img_paths(data_paths['train_gt_dir'])
    val_hazy_paths = _sorted_img_paths(data_paths['val_hazy_dir'])
    val_gt_paths = _sorted_img_paths(data_paths['val_gt_dir'])
    return (train_hazy_paths, train_gt_paths), (val_hazy_paths, val_gt_paths)


class Dehazing_Dataset(Dataset):
    """Paired hazy / ground-truth image dataset for supervised dehazing.

    __getitem__ returns a (hazy, gt) pair of float tensors in [0, 1].
    """

    def __init__(self, hazy_paths, gt_paths, transform=None):
        """
        Args:
            hazy_paths: list of file paths to hazy input images.
            gt_paths: list of file paths to haze-free targets, paired by index.
            transform: optional callable applied to both images of a pair.
        """
        super(Dehazing_Dataset, self).__init__()
        self.hazy_paths = hazy_paths
        self.gt_paths = gt_paths
        self.transform = transform

    def __len__(self):
        # Number of samples; assumes len(gt_paths) == len(hazy_paths).
        return len(self.hazy_paths)

    def __getitem__(self, item):
        # read_image yields a uint8 CHW tensor; divide to get floats in [0, 1].
        hazy_img = image.read_image(self.hazy_paths[item]) / 255.0
        gt_img = image.read_image(self.gt_paths[item]) / 255.0

        if self.transform:
            # BUG FIX: restore the torch RNG state before transforming gt so
            # random transforms (e.g. RandomCrop) draw the SAME parameters for
            # both images -- otherwise the input/target pair is spatially
            # misaligned and the supervision signal is wrong.
            rng_state = torch.get_rng_state()
            hazy_img = self.transform(hazy_img)
            torch.set_rng_state(rng_state)
            gt_img = self.transform(gt_img)
        return hazy_img, gt_img


def get_dataset():
    """Build train/val DataLoaders over the paired dehazing images.

    Returns:
        (train_dataloader, val_dataloader): batch size 2, shuffled.
    """
    (train_hazy, train_gt), (val_hazy, val_gt) = read_imgs_paths()

    crop = transforms.Compose([transforms.RandomCrop(size=(256, 287))])
    train_set = Dehazing_Dataset(train_hazy, train_gt, transform=crop)
    val_set = Dehazing_Dataset(val_hazy, val_gt, transform=crop)

    loader_kwargs = dict(batch_size=2, shuffle=True, num_workers=0)
    train_loader = DataLoader(train_set, **loader_kwargs)
    val_loader = DataLoader(val_set, **loader_kwargs)
    return train_loader, val_loader


def show_img(sample):
    """Display a (hazy, gt) image pair side by side with matplotlib.

    Each tensor is expected in CHW layout and is permuted to HWC for imshow.
    """
    hazy = sample[0].permute(1, 2, 0)
    gt = sample[1].permute(1, 2, 0)
    plt.figure(figsize=(10, 15))
    for idx, img in enumerate((hazy, gt), start=1):
        plt.subplot(1, 2, idx)
        plt.axis('off')
        plt.imshow(img)
    plt.show()


if __name__ == '__main__':
    # Smoke-test the pipeline: each batch is a [hazy, gt] pair of tensors.
    train_dataloader, val_dataloader = get_dataset()
    for i_batch, sample_batch in enumerate(train_dataloader):
        print(type(sample_batch))
        print(sample_batch[0].size())  # e.g. torch.Size([2, 3, 256, 287])
        print(sample_batch[1].size())  # e.g. torch.Size([2, 3, 256, 287])

二、使用PIL读取照片(注:下方示例代码实际通过skimage的io.imread读取图像)

import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,utils
import warnings
import pandas as pd
import os
import sklearn
from skimage import io,transform
import yaml
import pathlib


# Silence library warnings and print full numpy arrays (no "..." truncation).
warnings.filterwarnings("ignore")
# BUG FIX: np.printoptions(...) is a context manager and has no effect when
# called bare (and np.inf was passed positionally as `precision`).  The
# persistent form is set_printoptions(threshold=np.inf).
np.set_printoptions(threshold=np.inf)

gpu_is_available = torch.cuda.is_available()
print("GPU is {}".format( "available" if gpu_is_available else "not available"))


def read_yaml_data():
    """Load the experiment configuration from ./environments.yaml.

    Returns:
        dict: parsed YAML content (expected to contain a 'data_path' section).
    """
    file_path = './environments.yaml'
    with open(file_path, 'r', encoding='utf-8') as f:
        # safe_load is preferred over load(Loader=FullLoader): it only
        # constructs plain Python objects, which is all a config file needs.
        data = yaml.safe_load(f)
    return data


def _sorted_img_paths(dir_path):
    """Return the sorted string paths of all entries directly under dir_path."""
    return sorted(str(p) for p in pathlib.Path(dir_path).glob('*'))


def read_imgs_paths():
    """Collect (hazy, ground-truth) image path lists for train and val splits.

    Directory locations come from the 'data_path' section of environments.yaml.
    Each list is sorted so hazy/gt pairs line up by index (assumes matching
    filenames sort identically in both directories -- TODO confirm).

    Returns:
        ((train_hazy, train_gt), (val_hazy, val_gt)): four lists of str paths.
    """
    data_paths = read_yaml_data()['data_path']
    train_hazy_paths = _sorted_img_paths(data_paths['train_hazy_dir'])
    train_gt_paths = _sorted_img_paths(data_paths['train_gt_dir'])
    val_hazy_paths = _sorted_img_paths(data_paths['val_hazy_dir'])
    val_gt_paths = _sorted_img_paths(data_paths['val_gt_dir'])
    return (train_hazy_paths, train_gt_paths), (val_hazy_paths, val_gt_paths)


class Dehazing_Dataset(Dataset):
    """Paired hazy / ground-truth image dataset for supervised dehazing.

    __getitem__ returns a (hazy, gt) pair of images scaled to [0, 1].
    """

    def __init__(self, hazy_paths, gt_paths, transform=None):
        """
        Args:
            hazy_paths: list of file paths to hazy input images.
            gt_paths: list of file paths to haze-free targets, paired by index.
            transform: optional callable applied to both images of a pair.
        """
        super(Dehazing_Dataset, self).__init__()
        self.hazy_paths = hazy_paths
        self.gt_paths = gt_paths
        self.transform = transform

    def __len__(self):
        # Number of samples; assumes len(gt_paths) == len(hazy_paths).
        return len(self.hazy_paths)

    def __getitem__(self, item):
        # io.imread yields an HWC array (presumably uint8 -- TODO confirm);
        # dividing gives floats in [0, 1].
        hazy_img = io.imread(self.hazy_paths[item]) / 255.0
        gt_img = io.imread(self.gt_paths[item]) / 255.0

        if self.transform:
            # BUG FIX: restore the torch RNG state before transforming gt so
            # random transforms (e.g. RandomCrop) draw the SAME parameters for
            # both images -- otherwise the input/target pair is spatially
            # misaligned and the supervision signal is wrong.
            rng_state = torch.get_rng_state()
            hazy_img = self.transform(hazy_img)
            torch.set_rng_state(rng_state)
            gt_img = self.transform(gt_img)
        return hazy_img, gt_img


def get_dataset():
    """Build train/val DataLoaders over the paired dehazing images.

    Returns:
        (train_dataloader, val_dataloader): batch size 2, shuffled.
    """
    (train_hazy, train_gt), (val_hazy, val_gt) = read_imgs_paths()

    # ToTensor converts the HWC array to a CHW tensor before cropping.
    pipeline = transforms.Compose([transforms.ToTensor(),
                                   transforms.RandomCrop(size=(256, 287)),])
    train_set = Dehazing_Dataset(train_hazy, train_gt, transform=pipeline)
    val_set = Dehazing_Dataset(val_hazy, val_gt, transform=pipeline)

    loader_kwargs = dict(batch_size=2, shuffle=True, num_workers=0)
    train_loader = DataLoader(train_set, **loader_kwargs)
    val_loader = DataLoader(val_set, **loader_kwargs)
    return train_loader, val_loader


def show_img(sample):
    """Display a (hazy, gt) image pair side by side with matplotlib.

    Each tensor is expected in CHW layout and is permuted to HWC for imshow.
    """
    hazy = sample[0].permute(1, 2, 0)
    gt = sample[1].permute(1, 2, 0)
    plt.figure(figsize=(10, 15))
    for idx, img in enumerate((hazy, gt), start=1):
        plt.subplot(1, 2, idx)
        plt.axis('off')
        plt.imshow(img)
    plt.show()


if __name__ == '__main__':
    # Smoke-test the pipeline: each batch is a [hazy, gt] pair of tensors.
    train_dataloader, val_dataloader = get_dataset()
    for i_batch, sample_batch in enumerate(train_dataloader):
        print(type(sample_batch))
        print(sample_batch[0].size())  # e.g. torch.Size([2, 3, 256, 287])
        print(sample_batch[1].size())  # e.g. torch.Size([2, 3, 256, 287])

注意:
1.Pytorch读取图像数据的几种方式,可参考:链接: https://blog.csdn.net/qq_43665602/article/details/126281393
2.使用torchvision.io和PIL两种方式读取的数据范围为[0,255],并未进行归一化,我们可根据自己的需求对其进行归一化。

  • 方式一:transform.ToTensor()会自行将数据范围归一化为[0,1];
  • 方式二:transform.Normalize(mean,std)可通过调整合适的参数值得到自己想要的归一化结果;

你可能感兴趣的:(Pytorch,深度学习,pytorch,学习,python,深度学习,人工智能)