from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import argparse
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import numpy as np
import cv2
import os
# Restrict this process to physical GPUs 6 and 7; must be set before any CUDA
# initialization for PyTorch to honor it.
os.environ['CUDA_VISIBLE_DEVICES'] = '6, 7'
class xu_net(nn.Module):
    """XuNet-style CNN for binary image steganalysis (cover vs. stego).

    Pipeline: a fixed 5x5 high-pass (KV) residual filter, five conv groups
    (the first uses abs + BN + tanh, later groups use BN + ReLU, each
    followed by 5x5/stride-2 average pooling), global average pooling,
    and a 2-way linear classifier with log-softmax output.

    Fixes vs. the previous revision:
      * the KV kernel is built once in ``__init__`` (not on every forward
        pass) and registered as a buffer so it follows ``.to(device)``
        instead of being hard-coded to CUDA — the ``--no_cuda`` CPU path
        now works;
      * ``np.float`` (removed in NumPy >= 1.24) is no longer used;
      * deprecated ``F.tanh`` replaced by ``torch.tanh``.
    """

    def __init__(self):
        super(xu_net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 5)
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.conv3 = nn.Conv2d(16, 32, 1)
        self.conv4 = nn.Conv2d(32, 64, 1)
        self.conv5 = nn.Conv2d(64, 128, 1)
        self.fc = nn.Linear(128, 2)
        self.bn1 = nn.BatchNorm2d(8)
        self.bn2 = nn.BatchNorm2d(16)
        self.bn3 = nn.BatchNorm2d(32)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(128)
        self.avg_pool1_4 = nn.AvgPool2d(5, 2)
        # Fixed KV high-pass kernel (values /12, not trained).  Registered as
        # a NON-persistent buffer so it moves with the module across devices
        # but does not change the state_dict layout — old checkpoints saved
        # by this script still load.
        kv = torch.tensor([[-1., 2., -2., 2., -1.],
                           [2., -6., 8., -6., 2.],
                           [-2., 8., -12., 8., -2.],
                           [2., -6., 8., -6., 2.],
                           [-1., 2., -2., 2., -1.]]) / 12.
        self.register_buffer('hpf_kernel', kv.reshape(1, 1, 5, 5), persistent=False)

    def forward(self, x):
        """Classify a batch of 256x256 grayscale images.

        Args:
            x: tensor reshapeable to (N, 1, 256, 256).
        Returns:
            (N, 2) log-probabilities (log_softmax over dim 1).
        """
        x = x.reshape((-1, 1, 256, 256))  # (batch, channel, height, width)
        # High-pass residual extraction; padding=2 preserves spatial size.
        x = F.conv2d(x, self.hpf_kernel, padding=2)
        # group 1
        x = self.conv1(x)
        x = torch.abs(x)  # absolute activation, as in the XuNet paper
        x = self.bn1(x)
        x = torch.tanh(x)
        x = self.avg_pool1_4(x)
        # group 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.tanh(x)
        x = self.avg_pool1_4(x)
        # group 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.avg_pool1_4(x)
        # group 4
        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x)
        x = self.avg_pool1_4(x)
        # group 5
        x = self.conv5(x)
        x = self.bn5(x)
        x = F.relu(x)
        # global average pooling -> (N, 128)
        x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)
        # group 6: classifier
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
class myDatasets(Dataset):
    """Eagerly loaded dataset of grayscale cover/stego images.

    Every image found (recursively) under ``cover_path`` gets label 0 and
    every image under ``stego_path`` gets label 1.  Images are read as
    single-channel via ``cv2.imread(..., 0)``.

    Fix vs. the previous revision: ``cv2.imread`` returns ``None`` for
    files it cannot decode; those used to be appended blindly and crashed
    later inside the DataLoader.  They are now skipped.  Filenames are
    sorted so dataset order is reproducible across filesystems.
    """

    def __init__(self, cover_path, stego_path, transform=None):
        self.data = []
        self.labels = []
        self._load_dir(cover_path, 0)   # covers -> label 0
        self._load_dir(stego_path, 1)   # stegos -> label 1
        self.transform = transform

    def _load_dir(self, path, label):
        """Append every decodable grayscale image under `path` with `label`."""
        for root, dirs, filenames in os.walk(path):
            for filename in sorted(filenames):
                img = cv2.imread(os.path.join(root, filename), 0)
                if img is None:  # unreadable / non-image file: skip
                    continue
                self.data.append(img)
                self.labels.append(np.array([label]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index]
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
def train(args, loaders, optimizer, model, device, epoch):
    """Run one training epoch over `loaders`, logging every `args.log_interval` iterations.

    Args:
        args: namespace with at least `log_interval`.
        loaders: iterable of (images, labels) batches.
        optimizer: optimizer over `model`'s parameters.
        model: module returning log-probabilities (trained with NLL loss).
        device: target device for the batch tensors.
        epoch: current epoch index (for logging only).
    """
    model.train()
    for i, (images, labels) in enumerate(loaders):
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.long)
        optimizer.zero_grad()
        output = model(images)
        # nll_loss wants (N,) targets for (N, C) output.  view(-1) instead of
        # squeeze(): squeeze() collapses a batch of size 1 to a 0-d tensor,
        # which nll_loss rejects.
        loss = F.nll_loss(output, labels.view(-1))
        loss.backward()
        optimizer.step()
        if i % args.log_interval == 0:
            print('train epoch {}, {} iter, loss is {}'.format(epoch, i, loss.item()))
def test(loaders, optimizer, model, device, epoch):
model.eval()
test_loss = 0
correct_num = 0
with torch.no_grad():
for i, (images, labels) in enumerate(loaders):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
output = model(images)
pred = output.argmax(dim=1, keepdim=True)
test_loss += F.nll_loss(output, labels.squeeze(), reduction='sum').item()
correct_num += pred.eq(labels.view_as(pred)).sum().item()
test_loss /= len(loaders.dataset)
print('test epoch {}, mean test loss is {}, predict correct number {}, accuracy {}'.format(epoch, test_loss, correct_num, float(correct_num) / len(loaders.dataset)))
def main():
    """Parse CLI arguments, build cover/stego train & test sets, then run the
    train/evaluate loop, saving model weights every 20 epochs."""
    parser = argparse.ArgumentParser(description='Pytorch MNIST Example')
    parser.add_argument('--batch_size', type=int, default=36, metavar='N',
                        help='input batch size for training (default: 36)')
    parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=1000, metavar='N',
                        help='epochs for training (default: 1000)')
    parser.add_argument('--learning_rate', type=float, default=0.01, metavar='LR',
                        help='learning rate for training (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum for training (default: 0.5)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training (default: False)')
    parser.add_argument('--seed', type=int, default=1234, metavar='S',
                        help='random seed (default: 1234)')
    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                        help='show log message interval (default: 10)')
    parser.add_argument('--save_model', action='store_true', default=True,
                        help='saving model ')
    parser.add_argument('--train_cover_path', type=str, default='/program/data/cover/train', metavar='S',
                        help='training image path')
    parser.add_argument('--test_cover_path', type=str, default='/program/data/cover/test', metavar='S',
                        help='testing image path')
    parser.add_argument('--train_stego_path', type=str, default='/program/data/stc_adv_stego_5w/train', metavar='S',
                        help='training image path')
    parser.add_argument('--test_stego_path', type=str, default='/program/data/stc_adv_stego_5w/test', metavar='S',
                        help='testing image path')
    args = parser.parse_args()
    # Use CUDA only when available and not explicitly disabled.
    is_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device('cuda' if is_cuda else 'cpu')
    kwargs = {'num_workers': 1, 'pin_memory': True} if is_cuda else {}
    # NOTE(review): (0.1307, 0.3081) are the MNIST mean/std — presumably
    # copied from the PyTorch MNIST example; confirm these stats suit this
    # cover/stego image set.
    transform_img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_data = myDatasets(args.train_cover_path, args.train_stego_path, transform=transform_img)
    train_loader = data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_data = myDatasets(args.test_cover_path, args.test_stego_path, transform=transform_img)
    test_loader = data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=True, **kwargs)
    model = xu_net().to(device)
    # Positional args map to SGD(params, lr, momentum).
    optimizer = optim.SGD(model.parameters(), args.learning_rate, args.momentum)
    for epoch in range(args.epochs):
        train(args, train_loader, optimizer, model, device, epoch)
        test(test_loader, optimizer, model, device, epoch)
        # Overwrites the same checkpoint file every 20 epochs.
        if args.save_model and epoch % 20 == 0:
            torch.save(model.state_dict(), 'xuNet_cnn.pt')


if __name__ == '__main__':
    main()
# XuNet paper
# Reference 1: PyTorch Dataset and DataLoader
# Reference 2: How do I load custom image based datasets into Pytorch for use with a CNN?
# Reference 3: pytorch/examples — a minimal example (e.g. training MNIST)