Training a CNN model (spatial XuNet) implemented in PyTorch

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import argparse
import torch.utils.data as data
from torch.utils.data import Dataset

import numpy as np
import cv2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'  # restrict training to GPUs 6 and 7 (no space in the list, to be safe)

class xu_net(nn.Module):
    def __init__(self):
        super(xu_net, self).__init__()

        self.conv1 = nn.Conv2d(1, 8, 5)
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.conv3 = nn.Conv2d(16, 32, 1)
        self.conv4 = nn.Conv2d(32, 64, 1)
        self.conv5 = nn.Conv2d(64, 128, 1)

        self.fc = nn.Linear(128, 2)

        self.bn1 = nn.BatchNorm2d(8)
        self.bn2 = nn.BatchNorm2d(16)
        self.bn3 = nn.BatchNorm2d(32)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(128)

        self.avg_pool1_4 = nn.AvgPool2d(5, 2)

    def forward(self, x):
        # fixed KV high-pass filter from the XuNet design, shaped as a 1x1x5x5 kernel
        hpf = np.array([[-1,  2,  -2,  2, -1],
                        [ 2, -6,   8, -6,  2],
                        [-2,  8, -12,  8, -2],
                        [ 2, -6,   8, -6,  2],
                        [-1,  2,  -2,  2, -1]], dtype=np.float32) / 12.
        hpf = hpf.reshape((1, 1, 5, 5))
        # rebuilding the kernel on every forward pass works, but registering it as a
        # buffer in __init__ would avoid the repeated host-to-device copy
        F_0 = torch.from_numpy(hpf).to(x.device, dtype=x.dtype)
        x = x.reshape((-1, 1, 256, 256))  # x shape: (batch, channel, height, width)
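
        # feature-map sizes for a 256x256 input (learned convs are unpadded; a 5x5,
        # stride-2 average pool follows groups 1-4):
        #   HPF:     1 x 256 x 256
        #   group 1: 8 x 252 x 252   -> pool -> 8 x 124 x 124
        #   group 2: 16 x 120 x 120  -> pool -> 16 x 58 x 58
        #   group 3: 32 x 58 x 58    -> pool -> 32 x 27 x 27
        #   group 4: 64 x 27 x 27    -> pool -> 64 x 12 x 12
        #   group 5: 128 x 12 x 12   -> global average pooling -> 128 -> fc -> 2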

        # group 1
        x = F.conv2d(x, F_0, padding=2)
        x = self.conv1(x)
        x = torch.abs(x)
        x = self.bn1(x)
        x = torch.tanh(x)  # F.tanh is deprecated in favor of torch.tanh
        x = self.avg_pool1_4(x)

        # group 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.tanh(x)
        x = self.avg_pool1_4(x)

        # group 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.avg_pool1_4(x)

        # group 4
        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x)
        x = self.avg_pool1_4(x)

        # group 5
        x = self.conv5(x)
        x = self.bn5(x)
        x = F.relu(x)
        x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2) # global average pooling

        # group 6
        x = self.fc(x)

        return F.log_softmax(x, dim=1)


class myDatasets(Dataset):  # subclass Dataset to load the custom cover/stego image set
    def __init__(self, cover_path, stego_path, transform=None):
        self.data = []
        self.labels = []
        # covers are labeled 0 and stegos 1; images are read as single-channel grayscale
        for root, dirs, filenames in os.walk(cover_path):
            for filename in filenames:
                cover_coef = cv2.imread(os.path.join(root, filename), 0)
                if cover_coef is None:  # skip files that OpenCV cannot decode
                    continue
                self.data.append(cover_coef)
                self.labels.append(np.array([0]))

        for root, dirs, filenames in os.walk(stego_path):
            for filename in filenames:
                stego_coef = cv2.imread(os.path.join(root, filename), 0)
                if stego_coef is None:
                    continue
                self.data.append(stego_coef)
                self.labels.append(np.array([1]))
        self.transform = transform

    def __len__(self):
        return len(self.data)    
    
    def __getitem__(self, index):
        image = self.data[index]
        label = self.labels[index]
            
        if self.transform is not None:
            image = self.transform(image)

        return image, label
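
# Assumed data layout (not stated explicitly in the post): cover_path and stego_path each
# contain only 256x256 grayscale images, since os.walk picks up every file beneath them.
# A minimal sanity check of the dataset class, using the script's own default paths:
#
#   ds = myDatasets('/program/data/cover/train', '/program/data/stc_adv_stego_5w/train',
#                   transform=transforms.ToTensor())
#   img, label = ds[0]  # img: float tensor of shape (1, 256, 256) in [0, 1]; label: array([0]) or array([1])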


def train(args, loaders, optimizer, model, device, epoch):
    model.train()
    for i, (images, labels) in enumerate(loaders):
        images, labels = images.to(device, dtype=torch.float), labels.to(device, dtype=torch.long)
        optimizer.zero_grad()
        output = model(images)

        loss = F.nll_loss(output, labels.squeeze())  # output has size (N, C); the target must have size (N), hence the squeeze
        loss.backward()
        optimizer.step()

        if i % args.log_interval == 0:
            print('train epoch {}, {} iter, loss is {}'.format(epoch, i, loss.item()))


def test(loaders, optimizer, model, device, epoch):
    model.eval()
    test_loss = 0
    correct_num = 0

    with torch.no_grad():
        for i, (images, labels) in enumerate(loaders):
            images, labels = images.to(device, dtype=torch.float), labels.to(device, dtype=torch.long)
            output = model(images)
            pred = output.argmax(dim=1, keepdim=True)

            test_loss += F.nll_loss(output, labels.squeeze(), reduction='sum').item()
            correct_num += pred.eq(labels.view_as(pred)).sum().item()

    test_loss /= len(loaders.dataset)

    print('test epoch {}, mean test loss is {}, predict correct number {}, accuracy {}'.format(epoch, test_loss, correct_num, float(correct_num) / len(loaders.dataset)))


def main():
    parser = argparse.ArgumentParser(description='PyTorch spatial XuNet training example')
    parser.add_argument('--batch_size', type=int, default=36, metavar='N',
                        help='input batch size for training (default: 36)')
    parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=1000, metavar='N',
                        help='epochs for training (default: 1000)')
    parser.add_argument('--learning_rate', type=float, default=0.01, metavar='LR',
                        help='learning rate for training (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum for training (default: 0.5)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training (default: False)')
    parser.add_argument('--seed', type=int, default=1234, metavar='S',
                        help='random seed (default: 1234)')
    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                        help='show log message interval (default: 10)')
    parser.add_argument('--save_model', action='store_true', default=True,
                        help='saving model ')
    parser.add_argument('--train_cover_path', type=str, default='/program/data/cover/train', metavar='S',
                        help='training image path')
    parser.add_argument('--test_cover_path', type=str, default='/program/data/cover/test', metavar='S',
                        help='testing image path')
    parser.add_argument('--train_stego_path', type=str, default='/program/data/stc_adv_stego_5w/train', metavar='S',
                        help='training image path')
    parser.add_argument('--test_stego_path', type=str, default='/program/data/stc_adv_stego_5w/test', metavar='S',
                        help='testing image path')
    args = parser.parse_args()

    is_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device('cuda' if is_cuda else 'cpu')
    kwargs = {'num_workers': 1, 'pin_memory': True} if is_cuda else {}
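    # note: (0.1307, 0.3081) below are the MNIST mean/std carried over from the
    # pytorch/examples MNIST script (reference 3), not statistics of this cover/stego data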
    transform_img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_data = myDatasets(args.train_cover_path, args.train_stego_path, transform=transform_img)
    train_loader = data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, **kwargs)
    
    test_data = myDatasets(args.test_cover_path, args.test_stego_path, transform=transform_img)
    test_loader = data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=True, **kwargs)
    
    model = xu_net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum)

    for epoch in range(args.epochs):
        train(args, train_loader, optimizer, model, device, epoch)
        test(test_loader, optimizer, model, device, epoch)

        if args.save_model and epoch % 20 == 0:
            torch.save(model.state_dict(), 'xuNet_cnn.pt')


if __name__ == '__main__':
    main()
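
Once training has written xuNet_cnn.pt, the checkpoint can be reloaded for inference. The snippet below is a minimal sketch, assuming it runs in the same module as the code above (so xu_net and transforms are in scope); 'test.pgm' is a hypothetical file name, and the image is preprocessed exactly as during training.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = xu_net().to(device)
model.load_state_dict(torch.load('xuNet_cnn.pt', map_location=device))
model.eval()

transform_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

img = cv2.imread('test.pgm', 0)  # 'test.pgm' is a placeholder path, read as grayscale
x = transform_img(img).unsqueeze(0).to(device, dtype=torch.float)
with torch.no_grad():
    pred = model(x).argmax(dim=1).item()  # 0 = cover, 1 = stego
print('stego' if pred == 1 else 'cover')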


XuNet paper: Structural Design of Convolutional Neural Networks for Steganalysis (Xu et al., IEEE Signal Processing Letters, 2016)
Reference 1: PyTorch Dataset and DataLoader
Reference 2: How do I load custom image based datasets into Pytorch for use with a CNN?
Reference 3: pytorch/examples -- a minimal example (e.g. training on MNIST)
