# PyTorch 图像二分类代码 AlexNet (PyTorch binary image classification example, AlexNet)

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image

def check_image(path):
    """Return True if Pillow can open ``path`` as an image, otherwise False.

    Intended as the ``is_valid_file`` predicate for
    ``torchvision.datasets.ImageFolder``. Files for which it returns False
    are left out of the dataset.

    ``Image.open`` is lazy: it only identifies the format from the header
    and does not decode pixel data. A truncated or corrupt image body can
    therefore still pass this check.

    Args:
        path: Filesystem path of the candidate image file.

    Returns:
        True if the file opened as an image, False if it is missing,
        unreadable, not a recognised image format, or exceeds Pillow's
        decompression-bomb limit.
    """
    try:
        # Use the context manager so the file handle is released right away.
        # ImageFolder calls this once per file, so handles left open until
        # garbage collection can exhaust the descriptor limit on large trees.
        with Image.open(path):
            return True
    except (OSError, ValueError, Image.DecompressionBombError):
        # OSError covers FileNotFoundError, PermissionError and
        # PIL.UnidentifiedImageError (unrecognised format). Narrowing from a
        # bare ``except`` keeps KeyboardInterrupt and real programming errors
        # visible instead of silently marking files invalid.
        return False
    
# Preprocessing pipeline applied to every image loaded by the datasets below.
img_transforms = transforms.Compose(
    [
        # Scale each incoming image to the same fixed 64x64 resolution.
        transforms.Resize((64, 64)),
        # Convert the PIL image into a float tensor with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel normalisation: (x - mean) / std. These values match the
        # widely used ImageNet RGB statistics (presumably chosen for that).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ImageFolder expects each root to contain one subdirectory per class; the
# subdirectory names become the class labels (e.g. ./train/cat/001.jpg).
# Files rejected by check_image are skipped.
train_data_path = "./train/"
train_data = torchvision.datasets.ImageFolder(
    root=train_data_path,
    transform=img_transforms,
    is_valid_file=check_image,
)

# Validation set: same directory layout, preprocessing and file filter as the
# training set, so both are fed to the model identically.
val_data_path = "./val/"
val_data = torchvision.datasets.ImageFolder(
    root=val_data_path,
    transform=img_transforms,
    is_valid_file=check_image,
)

# 你可能感兴趣的 (you may also be interested in): Pytorch, pytorch, 深度学习 (deep learning), 机器学习 (machine learning)