深度学习数据集加载

 数据集结构

E:\Mytest\test20250622\pythonProject\dataset
├── rose
│   ├── rose1.jpg
│   ├── rose2.jpg
│   └── ...
└── sunflower
    ├── sunflower1.jpg
    ├── sunflower2.jpg
    └── ...

主要用到的两个类（Dataset 和 DataLoader），以及用于图像预处理的 transforms 模块：

from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader

 加载示例

data.py

import os
from torch.utils.data import Dataset
from PIL import Image


class MyDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Expected layout::

        root_dir/
            rose/
                rose1.jpg
                rose2.jpg
                ...
            sunflower/
                sunflower1.jpg
                sunflower2.jpg
                ...

    Class sub-directories are sorted by name and numbered ``0..N-1`` in that
    order, so the label assigned to a class is reproducible across runs and
    machines (``os.listdir`` itself returns entries in arbitrary order).

    Attributes set by ``__init__``:
        root_dir: the directory that was scanned.
        transform: optional callable applied to each loaded image.
        classes: sorted list of class (sub-directory) names; the position of
            a name in this list is its integer label.
        images: list of image file paths.
        labels: list of integer labels, parallel to ``images``.
    """

    def __init__(self, root_dir, transform=None):
        """Scan ``root_dir`` and record every image path with its label.

        Args:
            root_dir: directory whose immediate sub-directories are classes.
            transform: optional callable applied to the image in
                ``__getitem__``.

        Raises:
            OSError: propagated from ``os.listdir`` when ``root_dir`` (or a
                class directory) cannot be listed.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # Only directories count as classes. Filtering *before* enumerating
        # keeps label ids contiguous even when stray files sit in root_dir,
        # and sorting makes the class -> index mapping deterministic.
        self.classes = sorted(
            name
            for name in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, name))
        )

        for index, class_name in enumerate(self.classes):
            class_dir = os.path.join(self.root_dir, class_name)
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # Nested directories are not images; skip them so that
                # __getitem__ never tries to open one.
                if os.path.isfile(img_path):
                    self.images.append(img_path)
                    self.labels.append(index)

    def __len__(self):
        """Return the number of images found under ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``.

        The image is converted to RGB so that every sample has the same
        number of channels, and the underlying file is closed before
        returning. ``transform`` is applied when one was supplied.
        """
        # Image.open is lazy; convert() forces the pixel data to be read and
        # returns an independent image, so the file can be closed right away.
        with Image.open(self.images[idx]) as img:
            image = img.convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

main.py

from data import MyDataset
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Preprocessing pipeline: resize every image to 224x224, then turn it into a
# tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Build the custom Dataset from the image folder on disk.
dataset = MyDataset(
    root_dir='E:\\Mytest\\test20250622\\pythonProject\\dataset',
    transform=transform,
)

# Wrap the dataset in a DataLoader that yields shuffled batches of three.
dataloader = DataLoader(dataset, shuffle=True, batch_size=3)

# Walk over the batches and report what each one contains.
for i, (images, labels) in enumerate(dataloader):
    print(f"Batch {i + 1}, Image shape: {images.shape}, Labels: {labels}")

    # Sample output (e.g. with a five-image dataset):
    #   Batch 1, Image shape: torch.Size([3, 3, 224, 224]), Labels: tensor([1, 0, 1])
    #   Batch 2, Image shape: torch.Size([2, 3, 224, 224]), Labels: tensor([1, 0])

    # Convert the tensor into a displayable image layout (H x W x C).
    # NOTE(review): the snippet below looks like it only works with
    # batch_size=1 (squeeze(0) / labels.item()) - confirm before enabling.
    # image = images.squeeze(0).permute(1, 2, 0)  # drop the batch dim, channels last

    # Show the image.
    # plt.imshow(image)
    # plt.title(f"Label: {labels.item()}")
    # plt.axis('off')
    # plt.show()

    # Optional: stop after the first few batches.
    if i > 3:
        break

 

你可能感兴趣的:(深度学习,人工智能)