Generative Adversarial Networks: CGAN

GAN:
A Generative Adversarial Network (GAN) [4] consists of a generator and a discriminator. The idea is straightforward: the generator G takes random noise z as input and outputs an image G(z); the discriminator D tries its best to tell real samples x apart from generated ones, while G tries to make G(z) realistic enough to fool D so that it cannot tell the difference. This is the adversarial relationship that gives the model its name.
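In equation form, this adversarial game is the minimax objective from the original GAN paper:

    \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

CGAN conditions both players on side information y (here, a class label): D becomes D(x|y) and G becomes G(z|y). In practice this just means concatenating y onto the inputs of both networks, which is exactly what the code below does.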

Weaknesses of GANs:
1. Poorly suited to discrete data (e.g. text), since gradients cannot flow back through discrete sampling.
2. Unstable training.
3. Vanishing gradients.
4. Mode collapse.

Common GAN variants (the numbers refer to the weaknesses above):
DCGAN: addresses 2 and 3 (convolutional architecture with batchnorm).
WGAN: addresses 2 and 4 (replaces the JS divergence with the Wasserstein distance).
WGAN-GP: improves on all of the above, replacing WGAN's weight clipping with a gradient penalty (a sketch follows this list).
LSGAN: replaces the standard GAN loss with a least-squares loss, which eases training instability and improves the quality and diversity of the generated images.
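As a reference for the WGAN-GP idea, here is a minimal TF1-style sketch of the gradient penalty term. Everything here is illustrative rather than from the code in this post: `critic` is a hypothetical function mapping a batch of flat vectors of shape [batch, dim] to scores, and gp_lambda = 10 follows the WGAN-GP paper.

import tensorflow as tf

def gradient_penalty(critic, x_real, x_fake, gp_lambda=10.0):
    # Interpolate uniformly between real and fake samples
    batch_size = tf.shape(x_real)[0]
    eps = tf.random_uniform([batch_size, 1], 0., 1.)
    x_hat = eps * x_real + (1. - eps) * x_fake
    # Penalize the critic's gradient norm at x_hat for deviating from 1
    grads = tf.gradients(critic(x_hat), [x_hat])[0]
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-12)
    return gp_lambda * tf.reduce_mean(tf.square(grad_norm - 1.))

# The full critic loss would then be:
#   d_loss = tf.reduce_mean(critic(x_fake)) - tf.reduce_mean(critic(x_real)) \
#            + gradient_penalty(critic, x_real, x_fake)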

DCGAN:
DCGAN uses the same objective as GAN; it can be understood as GAN plus CNN. Plain GANs are hard to get to converge, and DCGAN makes several architectural changes to help. For example, the generator upsamples with transposed convolutional layers, and the discriminator replaces pooling with strided convolutions. G uses ReLU activations with tanh on the last layer; D uses LeakyReLU.
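A minimal sketch of one generator upsampling block following these guidelines (TF1 layers API; the helper name and arguments are illustrative, not from the code below):

import tensorflow as tf

def dcgan_g_block(x, out_channels, name, is_output=False):
    # One DCGAN generator stage: upsample 2x with a transposed convolution (no pooling)
    with tf.variable_scope(name):
        h = tf.layers.conv2d_transpose(x, out_channels, kernel_size=4, strides=2, padding='same')
        if is_output:
            return tf.nn.tanh(h)  # last layer of G uses tanh, outputs in [-1, 1]
        h = tf.layers.batch_normalization(h, training=True)  # batchnorm helps convergence
        return tf.nn.relu(h)  # hidden layers of G use ReLU (D would use LeakyReLU)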


1. CGAN on MNIST (handwritten digit generation)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
 
# Load the input data (MNIST)
mnist = input_data.read_data_sets('mnist/', one_hot=True)
mb_size = 64
Z_dim = 100
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
 
# Xavier initialization: returns randomly initialized weights of the given shape
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)
 
 
# Discriminator: a simple fully connected network; x is the flattened image vector, y is its one-hot label
def discriminator(x, y):
    inputs = tf.concat(axis=1, values=[x, y])
    D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)
 
    return D_prob, D_logit

 
# Generator network
def generator(z, y):
    inputs = tf.concat(axis=1, values=[z, y])
    G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)
 
    return G_prob
 
# Sample the input noise from a uniform distribution
def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])
 
 
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)
 
    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
 
    return fig
 


 
# X is the input image, flattened to 784 (28*28) since no CNN is used here; y is the corresponding label
""" Discriminator Net model """
X = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, y_dim])
# Weights: the CGAN discriminator concatenates the image with its label, so the first layer takes 784 + 10 inputs
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
# Output layer: maps the h_dim hidden units to a single logit
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))
 
theta_D = [D_W1, D_W2, D_b1, D_b2]


# Generator parameters: input dimension Z_dim + y_dim, hidden layer of h_dim units, output of X_dim values
""" Generator Net model """
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
# Weights
G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
 
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
 
theta_G = [G_W1, G_W2, G_b1, G_b2]


# Build the graph: essentially the same as a vanilla GAN, except both networks also receive the label y
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
# Losses
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
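# Note: this generator loss is the non-saturating variant from the original GAN paper:
# instead of minimizing log(1 - D(G(z))), G maximizes log D(G(z)) (labels=ones on the
# fake logits), which gives stronger gradients early on when D easily rejects fakes.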
# Training ops
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
 
 
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Generated images are written to the out/ folder
if not os.path.exists('out/'):
    os.makedirs('out/')
 
i = 0
 
for it in range(1000000):
    # mb_size is the training batch size (64, set above)
    X_mb, y_mb = mnist.train.next_batch(mb_size)
    # Z_dim is the noise dimension (100)
    Z_sample = sample_Z(mb_size, Z_dim)
    # Alternate the minimization: one D step, then one G step
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y:y_mb})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y:y_mb})
    # Print training diagnostics every 1000 iterations
    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))

    # Periodically sample from the generator
    if it % 1000 == 0:
        # n_sample is the batch size used when sampling from G (16, so each saved png contains a 4x4 grid of 16 images)
        n_sample = 16
 
        Z_sample = sample_Z(n_sample, Z_dim) # input noise, shape batch_size x noise_dim
        y_sample = np.zeros(shape=[n_sample, y_dim]) # input labels, shape batch_size x label_dim
        y_sample[:, 7] = 1 # condition on class 7, so the generator draws sevens
 
        samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample}) # run the generator on the sampled noise and labels
 
        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight') # save the generated image grid
        i += 1
        plt.close(fig)
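To condition on every digit instead of the fixed class 7 above, the label batch can simply be an identity matrix; a minimal sketch reusing the graph and session defined above:

Z_sample = sample_Z(10, Z_dim)
y_sample = np.eye(10)  # one one-hot row per class, digits 0 through 9
digits = sess.run(G_sample, feed_dict={Z: Z_sample, y: y_sample})  # one generated image per digit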

2. CGAN: generating realistic images from outlines

image_reader.py

import os
import numpy as np
import tensorflow as tf
import cv2
 
# Image-reading function; takes six parameters:
# file name, picture path, label path, picture format, label format, and the target size
def ImageReader(file_name, picture_path, label_path, picture_format = ".png", label_format = ".jpg", size = 256):
    picture_name = picture_path + file_name + picture_format # full path of the picture
    label_name = label_path + file_name + label_format # full path of the label
    picture = cv2.imread(picture_name, 1) # read the picture
    label = cv2.imread(label_name, 1) # read the label
    height = picture.shape[0] # original picture height
    width = picture.shape[1] # original picture width
    picture_resize_t = cv2.resize(picture, (size, size)) # resize the picture to the network input size
    picture_resize = picture_resize_t / 127.5 - 1. # normalize the picture to [-1, 1]
    label_resize_t = cv2.resize(label, (size, size)) # resize the label to the network input size
    label_resize = label_resize_t / 127.5 - 1. # normalize the label to [-1, 1]
    return picture_resize, label_resize, height, width # return the resized inputs plus the original height and width
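A quick usage sketch; the file name "0001" and the directories are assumptions about your dataset layout:

# Assumes img/0001.jpg (the input picture) and img2/0001.jpg (the matching label) exist
picture, label, h, w = ImageReader("0001", "img/", "img2/", ".jpg", ".jpg", 256)
print(picture.shape, label.shape, h, w)  # (256, 256, 3) (256, 256, 3) and the original size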

net.py

import numpy as np
import tensorflow as tf
import math
 
# Create a trainable variable
def make_var(name, shape, trainable = True):
    return tf.get_variable(name, shape, trainable = trainable)
 
# Convolution layer
def conv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "conv2d", biased = False):
    input_dim = input_.get_shape()[-1]
    with tf.variable_scope(name):
        kernel = make_var(name = 'weights', shape=[kernel_size, kernel_size, input_dim, output_dim])
        output = tf.nn.conv2d(input_, kernel, [1, stride, stride, 1], padding = padding)
        if biased:
            biases = make_var(name = 'biases', shape = [output_dim])
            output = tf.nn.bias_add(output, biases)
        return output
 
# Atrous (dilated) convolution layer
def atrous_conv2d(input_, output_dim, kernel_size, dilation, padding = "SAME", name = "atrous_conv2d", biased = False):
    input_dim = input_.get_shape()[-1]
    with tf.variable_scope(name):
        kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, input_dim, output_dim])
        output = tf.nn.atrous_conv2d(input_, kernel, dilation, padding = padding)
        if biased:
            biases = make_var(name = 'biases', shape = [output_dim])
            output = tf.nn.bias_add(output, biases)
        return output
 
# Transposed convolution (deconvolution) layer; upsamples the spatial size by the given stride
def deconv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "deconv2d"):
    input_dim = input_.get_shape()[-1]
    input_height = int(input_.get_shape()[1])
    input_width = int(input_.get_shape()[2])
    with tf.variable_scope(name):
        kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, output_dim, input_dim])
        # Note: the output shape hardcodes a batch size of 1, matching the training script below
        output = tf.nn.conv2d_transpose(input_, kernel, [1, input_height * stride, input_width * stride, output_dim], [1, stride, stride, 1], padding = padding)
        return output

# "batch_norm" layer; the moments are computed per sample over the spatial axes ([1, 2]), so this is effectively instance normalization
def batch_norm(input_, name="batch_norm"):
    with tf.variable_scope(name):
        input_dim = input_.get_shape()[-1]
        scale = tf.get_variable("scale", [input_dim], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        offset = tf.get_variable("offset", [input_dim], initializer=tf.constant_initializer(0.0))
        mean, variance = tf.nn.moments(input_, axes=[1,2], keep_dims=True)
        epsilon = 1e-5
        inv = tf.rsqrt(variance + epsilon)
        normalized = (input_-mean)*inv
        output = scale*normalized + offset
        return output
 
# Leaky ReLU activation
def lrelu(x, leak=0.2, name = "lrelu"):
    return tf.maximum(x, leak*x)
 
# Generator: a U-Net built from 8 convolutional and 8 transposed-convolutional layers with skip connections
def generator(image, gf_dim=64, reuse=False, name="generator"):
    input_dim = int(image.get_shape()[-1]) # number of input channels
    dropout_rate = 0.5 # dropout probability
    with tf.variable_scope(name):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        else:
            assert tf.get_variable_scope().reuse is False
            
        # Conv layer 1, output shape [1, 128, 128, 64]
        e1 = batch_norm(conv2d(input_=image, output_dim=gf_dim, kernel_size=4, stride=2, name='g_e1_conv'), name='g_bn_e1')
        # Conv layer 2, output shape [1, 64, 64, 128]
        e2 = batch_norm(conv2d(input_=lrelu(e1), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_e2_conv'), name='g_bn_e2')
        # Conv layer 3, output shape [1, 32, 32, 256]
        e3 = batch_norm(conv2d(input_=lrelu(e2), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_e3_conv'), name='g_bn_e3')
        # Conv layer 4, output shape [1, 16, 16, 512]
        e4 = batch_norm(conv2d(input_=lrelu(e3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e4_conv'), name='g_bn_e4')
        # Conv layer 5, output shape [1, 8, 8, 512]
        e5 = batch_norm(conv2d(input_=lrelu(e4), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e5_conv'), name='g_bn_e5')
        # Conv layer 6, output shape [1, 4, 4, 512]
        e6 = batch_norm(conv2d(input_=lrelu(e5), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e6_conv'), name='g_bn_e6')
        # Conv layer 7, output shape [1, 2, 2, 512]
        e7 = batch_norm(conv2d(input_=lrelu(e6), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e7_conv'), name='g_bn_e7')
        # Conv layer 8, output shape [1, 1, 1, 512]
        e8 = batch_norm(conv2d(input_=lrelu(e7), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e8_conv'), name='g_bn_e8')
 
        # Deconv layer 1, output shape [1, 2, 2, 512]
        d1 = deconv2d(input_=tf.nn.relu(e8), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d1')
        d1 = tf.nn.dropout(d1, dropout_rate) # randomly drop half of the activations
        d1 = tf.concat([batch_norm(d1, name='g_bn_d1'), e7], 3) # skip connection to e7
        # Deconv layer 2, output shape [1, 4, 4, 512]
        d2 = deconv2d(input_=tf.nn.relu(d1), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d2')
        d2 = tf.nn.dropout(d2, dropout_rate) # randomly drop half of the activations
        d2 = tf.concat([batch_norm(d2, name='g_bn_d2'), e6], 3) # skip connection to e6
        # Deconv layer 3, output shape [1, 8, 8, 512]
        d3 = deconv2d(input_=tf.nn.relu(d2), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d3')
        d3 = tf.nn.dropout(d3, dropout_rate) # randomly drop half of the activations
        d3 = tf.concat([batch_norm(d3, name='g_bn_d3'), e5], 3) # skip connection to e5
        # Deconv layer 4, output shape [1, 16, 16, 512]
        d4 = deconv2d(input_=tf.nn.relu(d3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d4')
        d4 = tf.concat([batch_norm(d4, name='g_bn_d4'), e4], 3) # skip connection to e4
        # Deconv layer 5, output shape [1, 32, 32, 256]
        d5 = deconv2d(input_=tf.nn.relu(d4), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_d5')
        d5 = tf.concat([batch_norm(d5, name='g_bn_d5'), e3], 3) # skip connection to e3
        # Deconv layer 6, output shape [1, 64, 64, 128]
        d6 = deconv2d(input_=tf.nn.relu(d5), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_d6')
        d6 = tf.concat([batch_norm(d6, name='g_bn_d6'), e2], 3) # skip connection to e2
        # Deconv layer 7, output shape [1, 128, 128, 64]
        d7 = deconv2d(input_=tf.nn.relu(d6), output_dim=gf_dim, kernel_size=4, stride=2, name='g_d7')
        d7 = tf.concat([batch_norm(d7, name='g_bn_d7'), e1], 3) # skip connection to e1
        # Deconv layer 8, output shape [1, 256, 256, 3]
        d8 = deconv2d(input_=tf.nn.relu(d7), output_dim=input_dim, kernel_size=4, stride=2, name='g_d8')
        return tf.nn.tanh(d8)
 
# Discriminator; the output is a 32x32 grid of per-patch real/fake probabilities rather than a single scalar
def discriminator(image, targets, df_dim=64, reuse=False, name="discriminator"):
    with tf.variable_scope(name):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        else:
            assert tf.get_variable_scope().reuse is False
        dis_input = tf.concat([image, targets], 3)

        # Conv block 1, output shape: 1*128*128*64
        h0 = lrelu(conv2d(input_ = dis_input, output_dim = df_dim, kernel_size = 4, stride = 2, name='d_h0_conv'))
        # Conv block 2, output shape: 1*64*64*128
        h1 = lrelu(batch_norm(conv2d(input_ = h0, output_dim = df_dim*2, kernel_size = 4, stride = 2, name='d_h1_conv'), name='d_bn1'))
        # Conv block 3, output shape: 1*32*32*256
        h2 = lrelu(batch_norm(conv2d(input_ = h1, output_dim = df_dim*4, kernel_size = 4, stride = 2, name='d_h2_conv'), name='d_bn2'))
        # Conv block 4, output shape: 1*32*32*512 (stride 1, so the spatial size is unchanged)
        h3 = lrelu(batch_norm(conv2d(input_ = h2, output_dim = df_dim*8, kernel_size = 4, stride = 1, name='d_h3_conv'), name='d_bn3'))
        # Final conv block, output shape: 1*32*32*1
        output = conv2d(input_ = h3, output_dim = 1, kernel_size = 4, stride = 1, name='d_h4_conv')
        dis_out = tf.sigmoid(output) # apply a sigmoid before output, since the loss takes a log of it
        return dis_out
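Because the last two convolutions use stride 1, this discriminator scores overlapping patches rather than emitting a single verdict per image (a PatchGAN, as in pix2pix). A quick shape check, assuming net.py is on the import path:

import tensorflow as tf
from net import discriminator

img = tf.placeholder(tf.float32, [1, 256, 256, 3])
tgt = tf.placeholder(tf.float32, [1, 256, 256, 3])
out = discriminator(image=img, targets=tgt)
print(out.get_shape())  # (1, 32, 32, 1): one real/fake probability per patch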

train.py

from __future__ import print_function
 
import argparse
from random import shuffle
import random
import os
import sys
import math
import tensorflow as tf
import glob
import cv2
 
from image_reader import *
from net import *
 
parser = argparse.ArgumentParser(description='')
 
parser.add_argument("--snapshot_dir", default='./snapshots', help="path of snapshots") #保存模型的路径
parser.add_argument("--out_dir", default='./train_out', help="path of train outputs") #训练时保存可视化输出的路径
parser.add_argument("--image_size", type=int, default=256, help="load image size") #网络输入的尺度
parser.add_argument("--random_seed", type=int, default=1234, help="random seed") #随机数种子
parser.add_argument('--base_lr', type=float, default=0.0002, help='initial learning rate for adam') #学习率
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')  #训练的epoch数量
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') #adam优化器的beta1参数
parser.add_argument("--summary_pred_every", type=int, default=200, help="times to summary.") #训练中每过多少step保存训练日志(记录一下loss值)
parser.add_argument("--write_pred_every", type=int, default=100, help="times to write.") #训练中每过多少step保存可视化结果
parser.add_argument("--save_pred_every", type=int, default=1000, help="times to save.") #训练中每过多少step保存模型(可训练参数)
parser.add_argument("--lamda_l1_weight", type=float, default=0.0, help="L1 lamda") #训练中L1_Loss前的乘数
parser.add_argument("--lamda_gan_weight", type=float, default=1.0, help="GAN lamda") #训练中GAN_Loss前的乘数
parser.add_argument("--train_picture_format", default='.jpg', help="format of training datas.") #网络训练输入的图片的格式(图片在CGAN中被当做条件)
parser.add_argument("--train_label_format", default='.jpg', help="format of training labels.") #网络训练输入的标签的格式(标签在CGAN中被当做真样本)
parser.add_argument("--train_picture_path", default='img/', help="path of training datas.") #网络训练输入的图片路径
parser.add_argument("--train_label_path", default='img2/', help="path of training labels.") #网络训练输入的标签路径
 
args = parser.parse_args() # parse the command-line arguments
EPS = 1e-12 # EPS keeps the arguments of the log calls strictly positive
 
def save(saver, sess, logdir, step): # checkpoint-saving helper
    model_name = 'model' # prefix of the saved model files
    checkpoint_path = os.path.join(logdir, model_name) # checkpoint path and name
    if not os.path.exists(logdir): # create the directory if it does not exist
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step) # save the model
    print('The checkpoint has been created.')
 
def cv_inv_proc(img): # undo the [-1, 1] normalization applied when the image was read
    img_rgb = (img + 1.) * 127.5
    return img_rgb.astype(np.float32) # returns a BGR image, convenient for writing with cv2
 
def get_write_picture(picture, gen_label, label, height, width): # build the visualization image saved during training
    picture_image = cv_inv_proc(picture) # restore the input image
    gen_label_image = cv_inv_proc(gen_label[0]) # restore the generated sample
    label_image = cv_inv_proc(label) # restore the real sample (the label)
    inv_picture_image = cv2.resize(picture_image, (width, height)) # resize the input image back to its original size
    inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) # resize the generated sample back to its original size
    inv_label_image = cv2.resize(label_image, (width, height)) # resize the real sample back to its original size
    output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) # concatenate them side by side
    return output
    
def l1_loss(src, dst): # L1 loss
    return tf.reduce_mean(tf.abs(src - dst))
 
def main(): # main training routine
    if not os.path.exists(args.snapshot_dir): # create the checkpoint directory if it does not exist
        os.makedirs(args.snapshot_dir)
    if not os.path.exists(args.out_dir): # create the visualization output directory if it does not exist
        os.makedirs(args.out_dir)

    # Build the list of training images and the input placeholders
    train_picture_list = glob.glob(os.path.join(args.train_picture_path, "*")) # list of training input image paths
    tf.set_random_seed(args.random_seed) # seed the random number generator
    train_picture = tf.placeholder(tf.float32, shape=[1, args.image_size, args.image_size, 3], name='train_picture') # input training image
    train_label = tf.placeholder(tf.float32, shape=[1, args.image_size, args.image_size, 3], name='train_label') # matching training label
    
    # The generator produces an image; the discriminator scores both the real label (train_label) and the generated one (gen_label)
    gen_label = generator(image=train_picture, gf_dim=64, reuse=False, name='generator') # generator output
    dis_real = discriminator(image=train_picture, targets=train_label, df_dim=64, reuse=False, name="discriminator") # discriminator result for the real label
    dis_fake = discriminator(image=train_picture, targets=gen_label, df_dim=64, reuse=True, name="discriminator") # discriminator result for the generated (fake) label

    # Generator losses
    gen_loss_GAN = tf.reduce_mean(-tf.log(dis_fake + EPS)) # GAN part of the generator loss
    gen_loss_L1 = tf.reduce_mean(l1_loss(gen_label, train_label)) # L1 part of the generator loss
    gen_loss = gen_loss_GAN * args.lamda_gan_weight + gen_loss_L1 * args.lamda_l1_weight # total generator loss
    # Discriminator loss
    dis_loss = tf.reduce_mean(-(tf.log(dis_real + EPS) + tf.log(1 - dis_fake + EPS))) # discriminator loss
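    # Note: with the default flags (lamda_l1_weight=0.0, lamda_gan_weight=1.0) the L1 term
    # is switched off; pix2pix-style training typically uses a large L1 weight (e.g. 100)
    # so the generated image also stays close to the ground-truth label pixel-wise.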
 
    gen_loss_sum = tf.summary.scalar("gen_loss", gen_loss) # summary for the generator loss
    dis_loss_sum = tf.summary.scalar("dis_loss", dis_loss) # summary for the discriminator loss
    # Summaries are written to snapshot_dir
    summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph()) # summary writer
 
    g_vars = [v for v in tf.trainable_variables() if 'generator' in v.name] # all trainable generator parameters
    d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name] # all trainable discriminator parameters
 
    # Optimizers and gradient updates
    d_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) # discriminator optimizer
    g_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) # generator optimizer

    d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars) # compute discriminator gradients
    d_train = d_optim.apply_gradients(d_grads_and_vars) # update discriminator parameters
    g_grads_and_vars = g_optim.compute_gradients(gen_loss, var_list=g_vars) # compute generator gradients
    g_train = g_optim.apply_gradients(g_grads_and_vars) # update generator parameters
    
    train_op = tf.group(d_train, g_train) # train_op runs both parameter updates together
    # GPU configuration
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True # do not pre-allocate all GPU memory

    # Create the session
    sess = tf.Session(config=config) # new session
    init = tf.global_variables_initializer() # variable initializer
    sess.run(init) # initialize all trainable parameters

    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50) # model saver

    counter = 0 # counter tracks the total number of training steps

    for epoch in range(args.epoch): # loop over the training epochs
        shuffle(train_picture_list) # reshuffle the input order at the start of every epoch
        for step in range(len(train_picture_list)): # training steps within one epoch
            counter += 1
            picture_name, _ = os.path.splitext(os.path.basename(train_picture_list[step])) # input image name without path or extension

            # Read one training picture and one training label, plus their original height and width
            picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name, 
                picture_path=args.train_picture_path, label_path=args.train_label_path, 
                picture_format = args.train_picture_format, label_format = args.train_label_format, 
                size = args.image_size)

            # Add the batch dimension to the picture and the label
            batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis = 0)
            batch_label = np.expand_dims(np.array(label_resize).astype(np.float32), axis = 0)
            # Feed the data and run one training step, fetching both losses
            feed_dict = { train_picture : batch_picture, train_label : batch_label } # build the feed_dict
            gen_loss_value, dis_loss_value, _ = sess.run([gen_loss, dis_loss, train_op], feed_dict=feed_dict) # generator and discriminator losses for this step

            # Periodically save checkpoints, summaries, and visualizations
            # (the original conditions compared against 100 instead of 0; with write_pred_every=100
            # the visualization branch could never fire, since counter % 100 is always below 100)
            if counter % args.save_pred_every == 0: # save the model every save_pred_every steps
                save(saver, sess, args.snapshot_dir, counter)

            if counter % args.summary_pred_every == 0: # write a training summary every summary_pred_every steps
                gen_loss_sum_value, discriminator_sum_value = sess.run([gen_loss_sum, dis_loss_sum], feed_dict=feed_dict)
                summary_writer.add_summary(gen_loss_sum_value, counter)
                summary_writer.add_summary(discriminator_sum_value, counter)

            if counter % args.write_pred_every == 0: # save a visualization every write_pred_every steps
                gen_label_value = sess.run(gen_label, feed_dict=feed_dict) # run the generator
                write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) # build the visualization
                write_image_name = args.out_dir + "/out"+ str(counter) + ".png" # path and name of the visualization file
                cv2.imwrite(write_image_name, write_image) # save it

            print('epoch {:d} step {:d} \t gen_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, gen_loss_value, dis_loss_value))
    
if __name__ == '__main__':
    main()
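Training can then be launched with the flags defined above; the paths are assumptions about where the paired data lives, and the L1 weight is a pix2pix-style suggestion rather than this script's default:

python train.py --train_picture_path img/ --train_label_path img2/ --lamda_l1_weight 100.0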

evaluate.py

import argparse
import os  # used below for os.path and os.makedirs; imported explicitly rather than via the wildcard imports
import sys
import math
import tensorflow as tf
import numpy as np
import glob
import cv2
 
from image_reader import *
from net import *
 
parser = argparse.ArgumentParser(description='')
 
parser.add_argument("--test_picture_path", default='img/', help="path of test datas.")#网络测试输入的图片路径
parser.add_argument("--test_label_path", default='img2/', help="path of test datas.") #网络测试输入的标签路径
parser.add_argument("--image_size", type=int, default=256, help="load image size") #网络输入的尺度
parser.add_argument("--test_picture_format", default='.jpg', help="format of test pictures.") #网络测试输入的图片的格式
parser.add_argument("--test_label_format", default='.jpg', help="format of test labels.") #网络测试时读取的标签的格式
parser.add_argument("--snapshots", default='snapshots/',help="Path of Snapshots") #读取训练好的模型参数的路径
parser.add_argument("--out_dir", default='test_output/',help="Output Folder") #保存网络测试输出图片的路径
 
args = parser.parse_args() # parse the command-line arguments
 
 
def cv_inv_proc(img): # undo the [-1, 1] normalization applied when the image was read
    img_rgb = (img + 1.) * 127.5
    return img_rgb.astype(np.float32) # returns a BGR image, convenient for writing with cv2
 
 
def get_write_picture(picture, gen_label, label, height, width): # build the image saved as the test result
    picture_image = cv_inv_proc(picture) # restore the input image
    gen_label_image = cv_inv_proc(gen_label[0]) # restore the generated result
    label_image = cv_inv_proc(label) # restore the label
    inv_picture_image = cv2.resize(picture_image, (width, height)) # resize the input image back to its original size
    inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) # resize the generated result back to its original size
    inv_label_image = cv2.resize(label_image, (width, height)) # resize the label back to its original size
    output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) # concatenate them side by side
    return output
 
 
def main():
    if not os.path.exists(args.out_dir): # create the output directory if it does not exist
        os.makedirs(args.out_dir)

    test_picture_list = glob.glob(os.path.join(args.test_picture_path, "*")) # list of test input image paths
    test_picture = tf.placeholder(tf.float32, shape=[1, args.image_size, args.image_size, 3], name='test_picture') # test input image (uses the image_size flag instead of a hardcoded 256)

    gen_label = generator(image=test_picture, gf_dim=64, reuse=False, name='generator') # generator output

    restore_var = [v for v in tf.global_variables() if 'generator' in v.name] # trained parameters to restore

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True # do not pre-allocate all GPU memory
    sess = tf.Session(config=config) # create the session

    saver = tf.train.Saver(var_list=restore_var, max_to_keep=1) # used to restore the model parameters
    checkpoint = tf.train.latest_checkpoint(args.snapshots) # find the latest checkpoint in the snapshots folder
    saver.restore(sess, checkpoint) # restore the trained parameters
 
    for step in range(len(test_picture_list)):
        picture_name, _ = os.path.splitext(os.path.basename(test_picture_list[step])) # name of one test input image
        # Read one test picture and its label, plus their original height and width
        picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name,
                                                                                  picture_path=args.test_picture_path,
                                                                                  label_path=args.test_label_path,
                                                                                  picture_format=args.test_picture_format,
                                                                                  label_format=args.test_label_format,
                                                                                  size=args.image_size)

        batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis=0) # add the batch dimension
        feed_dict = {test_picture: batch_picture} # build the feed_dict
        gen_label_value = sess.run(gen_label, feed_dict=feed_dict) # run the generator
        
        write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) # build the output image
        write_image_name = args.out_dir + picture_name + ".png" # path and name of the output file
        cv2.imwrite(write_image_name, write_image) # save the test result
        print('step {:d}'.format(step))
 
if __name__ == '__main__':
    main()
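Evaluation then runs the trained generator over every image in the test folder; the paths below are assumptions matching the defaults above:

python evaluate.py --test_picture_path img/ --test_label_path img2/ --snapshots snapshots/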
