Day 49

CBAM 训练：

import torch

import torch.nn as nn

import torch.optim as optim

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import numpy as np

# Matplotlib setup: use the SimHei font so Chinese labels can render, and
# draw minus signs with the ASCII hyphen (axes.unicode_minus=False), which is
# conventionally paired with SimHei since it may lack the U+2212 glyph.
plt.rcParams.update({
    "font.family": ["SimHei"],
    "axes.unicode_minus": False,
})

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用设备: {device}")

# Per-channel normalization statistics shared by the train and test pipelines.
# NOTE(review): this std tuple is the widely copied CIFAR-10 value; the
# per-pixel std over the training set is often reported as roughly
# (0.247, 0.243, 0.261). Confirm which is intended before changing it, since
# it alters the inputs the model is trained on.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline with data augmentation. Order matters: the random ops run
# on the PIL image before ToTensor, and Normalize runs last on the tensor.
train_transform = transforms.Compose(
    [
        # Random 32x32 crop taken from the image padded by 4 pixels per side.
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        # Random brightness/contrast/saturation/hue perturbation.
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        # Random rotation within +/-15 degrees.
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and the same
# normalization as training.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

@浙大疏锦行

# CIFAR-10 splits stored under ./data. Only the training split passes
# download=True; the test split (built next, from the same root) relies on the
# files fetched by this first call.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=test_transform,
)

# Mini-batches of 64. Shuffle the training data each epoch; keep the test
# order fixed so evaluation is reproducible.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=64)

你可能感兴趣的：机器学习、深度学习、人工智能