PyTorch图像分类系列——流程概览

一、加载数据

1.使用torchvision.datasets的方法加载经典数据集

 在此网址查看支持哪些经典数据集:Datasets — Torchvision 0.18 documentation (pytorch.org)

data_train = torchvision.datasets.CIFAR10(root="CIFAR10", train=True, transform=ToTensor(), 
                                          target_transform=None, download=True)
data_test = torchvision.datasets.CIFAR10(root="CIFAR10", train=False, transform=ToTensor(), 
                                         target_transform=None, download=True)

 下面三个参数是所有加载经典数据集的函数共有的参数:

  • root:存储数据集的目录
  • transform:通常为对图像数据进行一系列转换操作的函数
  • transform_target:通常为对目标数据进行一系列转换操作的函数 

 2.自己收集的数据集

(1)使用列表缓存图像和标签

root = "./data"
x = []
y = []
label_to_int = {}
int_to_label = {}
for kind in os.listdir(root):
    label = len(label_to_int)
    label_to_int[kind] = label
    int_to_label[label] = kind
    kind_root = os.path.join(root, kind)
    images_path = os.listdir(kind_root)
    for img_path in images_path:
        img = Image.open(os.path.join(kind_root,img_path)).convert("RGB")
        img = torchvision.transforms.ToTensor()(img)
        x.append(img)
        y.append(label)

(2)使用自定义Dataset 动态加载

class MyDataset(Dataset):
    def __init__(self, root):
        self.root = root
        self.image_paths = []
        self.labels = []
        self.label_to_int = {}
        self.int_to_label = {}
        for kind in os.listdir(self.root):
            label = len(self.label_to_int)
            s

你可能感兴趣的:(PyTorch图像分类系列,pytorch,深度学习,分类,计算机视觉)