Analyzing a TensorFlow Implementation of the U-Net Model: Training on the Carvana Dataset

1. Dataset Download and Code Directory Structure

We train the U-Net model on the dataset from the Kaggle Carvana Image Masking Challenge [download]. The downloaded archive contains the following files:

(Figure 1: contents of the downloaded archive)

For training the U-Net model we only need to extract train.zip and train_masks.zip. Extract them into the datasets directory; the resulting structure is shown below:

(Figure 2: directory structure after extraction)

The train and train_masks directories contain 5088 original images and their corresponding mask images, respectively; the masks are in GIF format.

The __init__.py files under the data and model directories are empty; their purpose is to let files in other directories import those directories as packages. For example, to import the data/read_tfrecords module from model/unet.py, we add the following to unet.py:

import sys
sys.path.append("../data")
from data import read_tfrecords

One problem I ran into: importing models and utils (which live in the same directory) from unet.py raised a module-not-found error. After some googling, the fix was to mark the model directory containing unet.py as a Sources Root: in PyCharm, right-click the model directory and choose Mark Directory as ==> Sources Root.
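A code-level alternative to the IDE setting is to extend sys.path at runtime, which works regardless of editor. A minimal sketch (the helper name is mine, not part of the project):

```python
import os
import sys

def add_to_sys_path(directory):
    """Prepend a directory to sys.path so its modules become importable.
    (Helper name is mine; a code-level alternative to the PyCharm setting.)"""
    directory = os.path.abspath(directory)
    if directory not in sys.path:
        sys.path.insert(0, directory)
    return directory

# e.g. from model/unet.py, make the sibling data/ directory importable:
# add_to_sys_path(os.path.join(os.path.dirname(__file__), "..", "data"))
```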

2. Data Preparation and Processing

First, run transform_images.py in the scripts directory. It reads the files under ../datasets/train/ as grayscale images and writes them to ../datasets/CarvanaImages/train/, and reads the files under ../datasets/train_masks/ as grayscale images and writes them to ../datasets/CarvanaImages/train_masks/, converting GIF to JPG.

import os
# import cv2 # Opencv can not read GIF image directly!
from PIL import Image


if __name__ == '__main__':
    # Root directory of the dataset.
    data_root = "../datasets"
    
    # Collect the file names under the train directory into the image_names list.
    image_names = os.listdir(os.path.join(data_root, "train"))
    
    # Process the train and train_masks directories in turn.
    for filename in ["train", "train_masks"]:
        for image_name in image_names:
            # Images under the train directory.
            if filename == "train":  # use ==, not "is", for string comparison
                # Path of each original training image, e.g. ../datasets/train/0cdf5b5d0ce1_01.jpg
                image_file = os.path.join(data_root, filename, image_name)
                
                # Read the image with PIL's Image class.
                # convert() converts between image modes; "L" converts to grayscale.
                image = Image.open(image_file).convert("L")

                # Create ../datasets/CarvanaImages/train/
                if not os.path.exists(os.path.join("../datasets/CarvanaImages", filename)):
                    os.makedirs(os.path.join("../datasets/CarvanaImages", filename))
                # Save path: ../datasets/CarvanaImages/train/image_name (e.g. 0cdf5b5d0ce1_01.jpg)
                image.save(os.path.join("../datasets/CarvanaImages", filename, image_name))

            if filename == "train_masks":
                # Path of each training mask, e.g. ../datasets/train_masks/0cdf5b5d0ce1_01_mask.gif
                image_file = os.path.join(data_root, filename, image_name[:-4] + "_mask.gif")
                image = Image.open(image_file).convert("L")
                
                # Create ../datasets/CarvanaImages/train_masks/
                if not os.path.exists(os.path.join("../datasets/CarvanaImages", filename)):
                    os.makedirs(os.path.join("../datasets/CarvanaImages", filename))
                
                # Save path: ../datasets/CarvanaImages/train_masks/ (0cdf5b5d0ce1_01_mask.gif ==> 0cdf5b5d0ce1_01_mask.jpg)
                image.save(os.path.join("../datasets/CarvanaImages", filename, 
                    image_name[:-4] + "_mask.jpg"))
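The script pairs each training image with its mask purely by file-name convention (`<name>.jpg` maps to `<name>_mask.gif`, converted to `<name>_mask.jpg`). That convention can be captured in a small helper (the function name is mine, for illustration):

```python
import os

def mask_name_for(image_name, ext=".jpg"):
    """Map an image file name to its mask file name using the convention above:
    '0cdf5b5d0ce1_01.jpg' -> '0cdf5b5d0ce1_01_mask.jpg' (helper name is mine)."""
    stem, _ = os.path.splitext(image_name)
    return stem + "_mask" + ext

print(mask_name_for("0cdf5b5d0ce1_01.jpg"))          # 0cdf5b5d0ce1_01_mask.jpg
print(mask_name_for("0cdf5b5d0ce1_01.jpg", ".gif"))  # 0cdf5b5d0ce1_01_mask.gif
```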

Next, run build_tfrecords.py in the scripts directory. It serializes each original image and its mask into strings and writes them into a single TFRecord file. TFRecord is a binary format that stores image data and labels together; it uses memory efficiently and can be copied, moved, read, and written quickly in TensorFlow.

import os
# import glob # Can use os.listdir(data_dir) replace glob.glob(os.path.join(data_dir, "*.jpg"))
# to get every image name, do not include path.
import tensorflow as tf

if __name__ == '__main__':
    # Directory of the grayscale images.
    data_root = "../datasets/CarvanaImages"
    # Collect the file names under CarvanaImages/train into the image_names list.
    image_names = os.listdir(os.path.join(data_root, "train")) # return JPEG image names.
    
    # Create the ../datasets/tfrecords/ directory.
    if not os.path.exists(os.path.join("../datasets", "tfrecords")):
        os.makedirs(os.path.join("../datasets", "tfrecords"))

    # tf.python_io.TFRecordWriter.__init__(path)
    # Step 1: create ../datasets/tfrecords/Carvana.tfrecords and a TFRecordWriter ready to write to it.
    writer = tf.python_io.TFRecordWriter(os.path.join("../datasets", "tfrecords",
        "Carvana.tfrecords"))

    for image_name in image_names:
        # Path of the original training image, e.g. datasets/CarvanaImages/train/0cdf5b5d0ce1_01.jpg
        image_raw_file = os.path.join(data_root, "train", image_name)
        # Path of its mask, e.g. datasets/CarvanaImages/train_masks/0cdf5b5d0ce1_01_mask.jpg
        image_label_file = os.path.join(data_root, "train_masks", 
            image_name[:-4] + "_mask.jpg")

        # Step 2: read the undecoded original image and its mask (the label).
        # tf.gfile.FastGFile(path, mode).read() reads the raw, undecoded image as a string of bytes;
        # to display the image it must first be decoded.
        # TensorFlow provides two decoding functions, tf.image.decode_jpeg and tf.image.decode_png,
        # for JPEG and PNG images respectively; decoding yields the pixel values.
        image_raw = tf.gfile.FastGFile(image_raw_file, 'rb').read() # image data type is string. 
        # read and binary.
        image_label = tf.gfile.FastGFile(image_label_file, 'rb').read()

        # A TFRecord file holds tf.train.Example protocol buffers (which contain Features).
        # You write code that loads your data, fills an Example protocol buffer, serializes it
        # to a string, and writes it to the TFRecords file via the tf.python_io.TFRecordWriter class.

        # tf.train.Example(features=tf.train.Features(feature={key: value, key: value, ...}))
        # value types: TFRecord supports integers, floats and bytes:
        # tf.train.Feature(int64_list=tf.train.Int64List(value=[int_scalar]))
        # tf.train.Feature(bytes_list=tf.train.BytesList(value=[array_string_or_byte]))
        # tf.train.Feature(float_list=tf.train.FloatList(value=[float_scalar]))
        # write bytes to Example proto buffer.
        
        # Step 3: fill the raw image and its label into the Example buffer.
        example = tf.train.Example(features=tf.train.Features(feature={
            "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
            "image_label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label]))
            }))

        # Step 4: serialize the Example buffer and write it to datasets/tfrecords/Carvana.tfrecords.
        writer.write(example.SerializeToString()) # Serialize To String
    
    writer.close()

3. The main Function Code

We focus only on the training-phase code.

main.py:

import tensorflow as tf
import numpy as np
from model import unet
import cv2


def main(argv):
    # tf.app.flags.FLAGS holds the command-line arguments and the defaults defined via tf.app.flags.
    tf_flags = tf.app.flags.FLAGS

    # gpu config.
    # tf.ConfigProto() is used when creating a session, to configure the session's parameters.
    config = tf.ConfigProto()
    
    # TF provides two ways to control GPU memory usage. The first limits the usage fraction:
    config.gpu_options.per_process_gpu_memory_fraction = 0.5    # use 50% of GPU memory
    # The second lets TensorFlow request memory dynamically, as needed, while running:
    # config.gpu_options.allow_growth = True

    if tf_flags.phase == "train":
        # Create the session with the config defined above.
        with tf.Session(config=config) as sess: 
        # when use queue to load data, not use with to define sess
            # Build the U-Net model.
            train_model = unet.UNet(sess, tf_flags)
            # Train the network; arguments: batch_size, number of training steps, ...
            train_model.train(tf_flags.batch_size, tf_flags.training_steps, 
                              tf_flags.summary_steps, tf_flags.checkpoint_steps, tf_flags.save_steps)
    else:
        with tf.Session(config=config) as sess:
            # test on an image pair.
            test_model = unet.UNet(sess, tf_flags)
            # Test phase: initialize the model parameters from a checkpoint file.
            test_model.load(tf_flags.checkpoint)
            image, output_masks = test_model.test()
            # return numpy ndarray.
            
            # save two images.
            filename_A = "input.png"
            filename_B = "output_masks.png"
            
            cv2.imwrite(filename_A, np.uint8(image[0].clip(0., 1.) * 255.))
            cv2.imwrite(filename_B, np.uint8(output_masks[0].clip(0., 1.) * 255.))

            # Utilize cv2.imwrite() to save images.
            print("Saved files: {}, {}".format(filename_A, filename_B))

if __name__ == '__main__':
    # tf.app.flags defines default parameters; they can be overridden by command-line arguments when running the Python file.
    tf.app.flags.DEFINE_string("output_dir", "model_output", 
                               "checkpoint and summary directory.")
    tf.app.flags.DEFINE_string("phase", "train", 
                               "model phase: train/test.")
    tf.app.flags.DEFINE_string("training_set", "./datasets", 
                               "dataset path for training.")
    tf.app.flags.DEFINE_string("testing_set", "./datasets/test", 
                               "dataset path for testing one image pair.")
    tf.app.flags.DEFINE_integer("batch_size", 64, 
                                "batch size for training.")
    tf.app.flags.DEFINE_integer("training_steps", 100000, 
                                "total training steps.")
    tf.app.flags.DEFINE_integer("summary_steps", 100, 
                                "summary period.")
    tf.app.flags.DEFINE_integer("checkpoint_steps", 1000, 
                                "checkpoint period.")
    tf.app.flags.DEFINE_integer("save_steps", 500, 
                                "checkpoint period.")
    tf.app.flags.DEFINE_string("checkpoint", None, 
                                "checkpoint name for restoring.")
    tf.app.run(main=main)
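For readers unfamiliar with tf.app.flags: it behaves much like the standard library's argparse. A rough equivalent of a few of the flags above, sketched with argparse (this is not what the original code uses, just an illustration of the same idea):

```python
import argparse

# A stand-in for some of the tf.app.flags definitions above, using argparse.
def build_parser():
    p = argparse.ArgumentParser()
    p.add_argument("--output_dir", default="model_output", help="checkpoint and summary directory.")
    p.add_argument("--phase", default="train", help="model phase: train/test.")
    p.add_argument("--batch_size", type=int, default=64, help="batch size for training.")
    p.add_argument("--training_steps", type=int, default=100000, help="total training steps.")
    p.add_argument("--checkpoint", default=None, help="checkpoint name for restoring.")
    return p

# Defaults apply unless overridden, exactly as with tf.app.flags:
args = build_parser().parse_args(["--phase", "test", "--batch_size", "8"])
print(args.phase, args.batch_size)  # test 8
```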

4. U-Net Model Code

The network input shape is [batch_size, 512, 512, 1]; the output shape is [batch_size, 324, 324, 1].

Note that in the encoder's five convolution blocks, each cropped feature map is used only for concatenation with the decoder feature map of the same resolution; the pooling operates on the uncropped feature map.
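The 512 ==> 324 relationship can be verified with a little arithmetic: every "valid" 3x3 convolution removes 2 pixels of border, 2x2 max-pooling halves the size, and each upsampling doubles it. A small sketch tracing the sizes through the network:

```python
def unet_valid_output_size(size):
    """Trace the spatial size through the network below: every 'valid' 3x3
    convolution loses 2 pixels, 2x2 max-pooling halves the size, and each
    upsampling doubles it (the 2x2 'same' conv after upsampling keeps the size)."""
    for _ in range(4):           # encoder: 4 blocks of (conv3x3, conv3x3, pool)
        size = (size - 4) // 2
    size -= 4                    # bottleneck: two more valid 3x3 convs
    for _ in range(4):           # decoder: 4 blocks of (upsample, conv3x3, conv3x3)
        size = size * 2 - 4
    return size

print(unet_valid_output_size(512))  # 324
```

The same arithmetic reproduces the 572 ==> 388 sizes of the original U-Net paper.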

models.py in the model directory:

import tensorflow as tf

def Unet(name, in_data, reuse=False):
    # Not use BatchNorm or InstanceNorm.
    # Make sure the input is not empty.
    assert in_data is not None
    # reuse=False: do not share variables.
    with tf.variable_scope(name, reuse=reuse):
        # After every two conv layers, the feature map is cropped once for merging with the decoder;
        # pooling halves the resolution of the (uncropped) convolved feature map.
        #size=[None,512,512,1]==>[None,510,510,64]
        conv1_1 = tf.layers.conv2d(in_data, 64, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0)) # Use Xavier init.
        # Arguments: inputs, filters, kernel_size, strides((1, 1)), padding(VALID). 
        # Appoint activation, use_bias, kernel_initializer, bias_initializer=tf.zeros_initializer().
        # In Keras's implement, kernel_initializer is he_normal, i.e. 
        # mean = 0.0, stddev = sqrt(2 / fan_in).

        #size=[None,510,510,64]==>[None,508,508,64]
        conv1_2 = tf.layers.conv2d(conv1_1, 64, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #(90, 90), (90, 90): pixels cropped from top/bottom and left/right. size=[None,508,508,64]==>[None,328,328,64]
        crop1 = tf.keras.layers.Cropping2D(cropping=((90, 90), (90, 90)))(conv1_2)

        #size=[None,508,508,64]==>[None,254,254,64]
        pool1 = tf.layers.max_pooling2d(conv1_2, 2, 2)

        #size=[None,254,254,64]==>[None,252,252,128]
        conv2_1 = tf.layers.conv2d(pool1, 128, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #size=[None,252,252,128]==>[None,250,250,128]
        conv2_2 = tf.layers.conv2d(conv2_1, 128, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        # crop: size=[None,250,250,128]==>[None,168,168,128]
        crop2 = tf.keras.layers.Cropping2D(cropping=((41, 41), (41, 41)))(conv2_2)

        #size=[None,250,250,128]==>[None,125,125,128]
        pool2 = tf.layers.max_pooling2d(conv2_2, 2, 2)

        #size=[None,125,125,128]==>[None,123,123,256]
        conv3_1 = tf.layers.conv2d(pool2, 256, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #size=[None,123,123,256]==>[None,121,121,256]
        conv3_2 = tf.layers.conv2d(conv3_1, 256, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #size=[None,121,121,256]==>[None,88,88,256]
        crop3 = tf.keras.layers.Cropping2D(cropping=((16, 17), (16, 17)))(conv3_2)

        # size=[None,121,121,256]==>[None,60,60,256]
        pool3 = tf.layers.max_pooling2d(conv3_2, 2, 2)

        #size=[None,60,60,256]==>[None,58,58,512]
        conv4_1 = tf.layers.conv2d(pool3, 512, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #size=[None,58,58,512]==>[None,56,56,512]
        conv4_2 = tf.layers.conv2d(conv4_1, 512, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,56,56,512]
        drop4 = tf.layers.dropout(conv4_2)
        # Arguments: inputs, rate=0.5.

        #size=[None,56,56,512]==>[None,48,48,512]
        crop4 = tf.keras.layers.Cropping2D(cropping=((4, 4), (4, 4)))(drop4)

        #[None,56,56,512]==>[None,28,28,512]
        pool4 = tf.layers.max_pooling2d(drop4, 2, 2)

        #[None,28,28,512]==>[None,26,26,1024]
        conv5_1 = tf.layers.conv2d(pool4, 1024, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,26,26,1024]==>[None,24,24,1024]
        conv5_2 = tf.layers.conv2d(conv5_1, 1024, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,24,24,1024]
        drop5 = tf.layers.dropout(conv5_2)



        # Decoder:
        # upsampling: [None,24,24,1024]==>[None,48,48,1024]
        up6_1 = tf.keras.layers.UpSampling2D(size=(2, 2))(drop5)
        '''
        Class UpSampling2D, Upsampling layer for 2D inputs. Arguments:
        size: int, or tuple of 2 integers. The upsampling factors for rows and columns.
        '''
        # Each upsampling is followed by a 2x2 "same" convolution that halves the channels: [None,48,48,1024]==>[None,48,48,512]
        up6 = tf.layers.conv2d(up6_1, 512, 2, padding="SAME", activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        # After each upsampling, concatenate with the encoder feature map of the same resolution ==> [None,48,48,1024]
        merge6 = tf.concat([crop4, up6], axis=3) # concat channel
        # values: A list of Tensor objects or a single Tensor.

        #[None,48,48,1024]==>[None,46,46,512]
        conv6_1 = tf.layers.conv2d(merge6, 512, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,46,46,512]==>[None,44,44,512]
        conv6_2 = tf.layers.conv2d(conv6_1, 512, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        # upsampling: [None,44,44,512]==>[None,88,88,512]
        up7_1 = tf.keras.layers.UpSampling2D(size=(2, 2))(conv6_2)

        #[None,88,88,256]
        up7 = tf.layers.conv2d(up7_1, 256, 2, padding="SAME", activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,88,88,512]
        merge7 = tf.concat([crop3, up7], axis=3) # concat channel

        #[None,88,88,512]==>[None,86,86,256]
        conv7_1 = tf.layers.conv2d(merge7, 256, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,86,86,256]==>[None,84,84,256]
        conv7_2 = tf.layers.conv2d(conv7_1, 256, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,84,84,256]==>[None,168,168,256]
        up8_1 = tf.keras.layers.UpSampling2D(size=(2, 2))(conv7_2)
        #[None,168,168,256]==>[None,168,168,128]
        up8 = tf.layers.conv2d(up8_1, 128, 2, padding="SAME", activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))
        #[None,168,168,256]
        merge8 = tf.concat([crop2, up8], axis=3) # concat channel

        #[None,168,168,256]==>[None,166,166,128]
        conv8_1 = tf.layers.conv2d(merge8, 128, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))
        #[None,166,166,128]==>[None,164,164,128]
        conv8_2 = tf.layers.conv2d(conv8_1, 128, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,164,164,128]==>[None,328,328,128]
        up9_1 = tf.keras.layers.UpSampling2D(size=(2, 2))(conv8_2)
        #[None,328,328,64]
        up9 = tf.layers.conv2d(up9_1, 64, 2, padding="SAME", activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))
        #[None,328,328,128]
        merge9 = tf.concat([crop1, up9], axis=3) # concat channel

        #[None,328,328,128]==>[None,326,326,64]
        conv9_1 = tf.layers.conv2d(merge9, 64, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))
        #[None,326,326,64]==>[None,324,324,64]
        conv9_2 = tf.layers.conv2d(conv9_1, 64, 3, activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))
        #[None,324,324,2]
        conv9_3 = tf.layers.conv2d(conv9_2, 2, 3, padding="SAME", activation=tf.nn.relu,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))

        #[None,324,324,2]=>[None,324,324,1]
        conv10 = tf.layers.conv2d(conv9_3, 1, 1,
            kernel_initializer = tf.contrib.layers.xavier_initializer())
            # kernel_initializer = tf.variance_scaling_initializer(scale=2.0))
        # 1 channel.

    return conv10
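The crop amounts used above ((90, 90), (41, 41), (16, 17), (4, 4)) are not arbitrary: each is the difference between an encoder feature map and the decoder map it merges with, split as evenly as Cropping2D allows. A quick check (the helper name is mine):

```python
def crop_amounts(encoder_size, decoder_size):
    """Total pixels to trim from an encoder feature map so it matches the
    decoder map it is concatenated with, split into the (before, after)
    pair that tf.keras.layers.Cropping2D expects. Helper name is mine."""
    diff = encoder_size - decoder_size
    return (diff // 2, diff - diff // 2)

# Encoder sizes before pooling vs. decoder sizes after upsampling:
for enc, dec in [(508, 328), (250, 168), (121, 88), (56, 48)]:
    print(crop_amounts(enc, dec))
# (90, 90), (41, 41), (16, 17), (4, 4) -- the values used in models.py
```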

5. The UNet Class

unet.py in the model directory:

import os
import logging
import time
from datetime import datetime
import tensorflow as tf
from models import Unet
from utils import save_images
import sys
sys.path.append("../data")
from data import read_tfrecords
import numpy as np
import cv2
import glob

class UNet(object):
    def __init__(self, sess, tf_flags):
        self.sess = sess
        self.dtype = tf.float32
        
        # Directory for model output, e.g. model_output_20190305195925/
        self.output_dir = tf_flags.output_dir
        # Directory for checkpoint files, e.g. model_output_20190305195925/checkpoint/
        self.checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
        # Prefix of checkpoint file names.
        self.checkpoint_prefix = "model"
        self.saver_name = "checkpoint"
        # Directory for summary files, e.g. model_output_20190305195925/summary/
        self.summary_dir = os.path.join(self.output_dir, "summary")

        self.is_training = (tf_flags.phase == "train")
        # Initial learning rate.
        self.learning_rate = 0.001

        # data parameters
        # Network input image size: 512 * 512 * 1.
        self.image_w = 512
        self.image_h = 512 # The raw and mask image is 1918 * 1280.
        self.image_c = 1

        # Input shape: [None, 512, 512, 1]
        self.input_data = tf.placeholder(self.dtype, [None, self.image_h, self.image_w, 
            self.image_c])
        # Mask shape: [None, 324, 324, 1]
        self.input_masks = tf.placeholder(self.dtype, [None, 324, 324, 
            self.image_c])
        # TODO: The shape of image masks. Refer to the Unet in model.py, the output image is
        # 324 * 324 * 1. But is not good.

        # Placeholder for the learning rate.
        self.lr = tf.placeholder(self.dtype)

        # train
        if self.is_training:
            # Training-set directory.
            self.training_set = tf_flags.training_set
            self.sample_dir = "train_results"

            # Create summary_dir, checkpoint_dir and sample_dir.
            self._make_aux_dirs()

            # Define the loss, optimizer, summary and saver.
            self._build_training()

            # Log file path.
            log_file = self.output_dir + "/Unet.log"
            logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s',   # log format: time, level name, message
                                filename=log_file,                                  # log file name
                                level=logging.DEBUG,                                # only messages at DEBUG level or above are written
                                filemode='w')                                       # mode used to open the log file
            # logging.getLogger() creates a logger;
            # addHandler() attaches a StreamHandler to it.
            logging.getLogger().addHandler(logging.StreamHandler())
        else:
            # test
            self.testing_set = tf_flags.testing_set
            # build model
            self.output = self._build_test()

    def _build_training(self):
        '''
        Define self.loss, self.opt, self.summary, self.writer and self.saver.
        '''
        # Unet input_data: [None,512,512,1]
        # output: [None,324,324,1]
        self.output = Unet(name="UNet", in_data=self.input_data, reuse=False)

        # loss.
        self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=self.input_masks, logits=self.output))
        # self.loss = tf.reduce_mean(tf.squared_difference(self.input_masks,
        #     self.output))
        # Use Tensorflow and Keras at the same time.
        # self.loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(
        #     self.input_masks, self.output))
        
        # optimizer
        # Define the Adam optimizer.
        self.opt = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(
            self.loss, name="opt")
        
        # summary
        tf.summary.scalar('loss', self.loss)
        
        self.summary = tf.summary.merge_all()
        # summary and checkpoint
        self.writer = tf.summary.FileWriter(
            self.summary_dir, graph=self.sess.graph)
        # Keep at most the 10 most recent checkpoint files.
        self.saver = tf.train.Saver(max_to_keep=10, name=self.saver_name)
        self.summary_proto = tf.Summary()


    def train(self, batch_size, training_steps, summary_steps, checkpoint_steps, save_steps):
        '''
        Arguments:
        batch_size: training batch size
        training_steps: total number of training steps
        summary_steps: save a summary every this many steps
        checkpoint_steps: save a checkpoint file every this many steps
        save_steps: save sample images every this many steps
        '''
        step_num = 0
        # Restore the last checkpoint, e.g. <checkpoint_dir>/model-10000.index
        # (the original code hard-coded "model_output_20180314110555/checkpoint" here).
        latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
        

        # A checkpoint file exists.
        if latest_checkpoint:
            step_num = int(os.path.basename(latest_checkpoint).split("-")[1])
            assert step_num > 0, "Please ensure checkpoint format is model-*.*."
            
            # Restore the model from the latest checkpoint.
            self.saver.restore(self.sess, latest_checkpoint)
            logging.info("{}: Resume training from step {}. Loaded checkpoint {}".format(datetime.now(), 
                step_num, latest_checkpoint))
        else:
            # No checkpoint found: initialize the model parameters.
            self.sess.run(tf.global_variables_initializer()) # init all variables
            logging.info("{}: Init new training".format(datetime.now()))

        # Create a Read_TFRecords object, tf_reader.
        tf_reader = read_tfrecords.Read_TFRecords(filename=os.path.join(self.training_set, 
            "Carvana.tfrecords"),
            batch_size=batch_size, image_h=self.image_h, image_w=self.image_w, 
            image_c=self.image_c)

        # [batch_size,512,512,1], [batch_size,324,324,1]
        images, images_masks = tf_reader.read()
        logging.info("{}: Done init data generators".format(datetime.now()))

        # Thread coordinator.
        self.coord = tf.train.Coordinator()
        # tf.train.start_queue_runners starts the threads that fill the queues; only then does the
        # pipeline stop "stalling": the compute units can fetch data and the whole program runs.
        threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)
        try:
            # train
            c_time = time.time()
            lrval = self.learning_rate
            for c_step in range(step_num + 1, training_steps + 1):
                # After 5000 steps, halve the learning rate (the same halved value is
                # re-assigned at every later multiple of 5000).
                if c_step % 5000 == 0:
                    lrval = self.learning_rate * .5
                
                batch_images, batch_images_masks = self.sess.run([images, images_masks])
                # Feeds needed for the backward pass.
                c_feed_dict = {
                    # TFRecord
                    self.input_data: batch_images,
                    self.input_masks: batch_images_masks,
                    self.lr: lrval
                }
                self.sess.run(self.opt, feed_dict=c_feed_dict)

                # save summary
                if c_step % summary_steps == 0:
                    # Evaluate the loss summary.
                    c_summary = self.sess.run(self.summary, feed_dict=c_feed_dict)
                    # Write the summary file.
                    self.writer.add_summary(c_summary, c_step)

                    e_time = time.time() - c_time
                    time_periter = e_time / summary_steps
                    logging.info("{}: Iteration_{} ({:.4f}s/iter) {}".format(
                        datetime.now(), c_step, time_periter,
                        self._print_summary(c_summary)))        #self._print_summary(c_summary):(loss=0.665075540543)
                    c_time = time.time() # update time

                # save checkpoint
                if c_step % checkpoint_steps == 0:
                    self.saver.save(self.sess,
                        os.path.join(self.checkpoint_dir, self.checkpoint_prefix),
                        global_step=c_step)
                    logging.info("{}: Iteration_{} Saved checkpoint".format(
                        datetime.now(), c_step))

                # Save sample images.
                if c_step % save_steps == 0:
                    # Predicted segmentation masks and ground-truth masks.
                    _, output_masks, input_masks = self.sess.run(
                        [self.input_data, self.output, self.input_masks],
                        feed_dict=c_feed_dict)
                    # [batch_size,324,324,1]
                    save_images(None, output_masks, input_masks,
                        # self.sample_dir: train_results
                        input_path = './{}/input_{:04d}.png'.format(self.sample_dir, c_step),
                        image_path = './{}/train_{:04d}.png'.format(self.sample_dir, c_step))
        except KeyboardInterrupt:
            print('Interrupted')
            self.coord.request_stop()
        except Exception as e:
            self.coord.request_stop(e)
        finally:
            # The main thread is done computing: stop all data-loading threads...
            self.coord.request_stop()
            # ...and wait for the other threads to finish.
            self.coord.join(threads)
        logging.info("{}: Done training".format(datetime.now()))

    def _build_test(self):
        # network.
        output = Unet(name="UNet", in_data=self.input_data, reuse=False)

        self.saver = tf.train.Saver(max_to_keep=10, name=self.saver_name) 
        # define saver, after the network!

        return output

    def load(self, checkpoint_name=None):
        # restore checkpoint
        print("{}: Loading checkpoint...".format(datetime.now()))
        if checkpoint_name:
            checkpoint = os.path.join(self.checkpoint_dir, checkpoint_name)
            self.saver.restore(self.sess, checkpoint)
            print(" loaded {}".format(checkpoint_name))
        else:
            # restore latest model
            latest_checkpoint = tf.train.latest_checkpoint(
                self.checkpoint_dir)
            if latest_checkpoint:
                self.saver.restore(self.sess, latest_checkpoint)
                print(" loaded {}".format(os.path.basename(latest_checkpoint)))
            else:
                raise IOError(
                    "No checkpoints found in {}".format(self.checkpoint_dir))

    def test(self):
        # Test on a single image.
        image_name = glob.glob(os.path.join(self.testing_set, "*.jpg"))
        
        # In this pipeline the test image must be divided by 255.0.
        image = np.reshape(cv2.resize(cv2.imread(image_name[0], 0), 
            (self.image_h, self.image_w)), (1, self.image_h, self.image_w, self.image_c)) / 255.
        # OpenCV loads images as BGR with shape (H, W, C); cv2.imread(..., 0) loads grayscale.

        print("{}: Done init data generators".format(datetime.now()))

        c_feed_dict = {
            self.input_data: image
        }

        output_masks = self.sess.run(
            self.output, feed_dict=c_feed_dict)

        return image, output_masks
        # image: 1 * 512 * 512 * 1
        # output_masks: 1 * 324 * 324 * 1.

    def _make_aux_dirs(self):
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        if not os.path.exists(self.sample_dir):
            os.makedirs(self.sample_dir)

    def _print_summary(self, summary_string):
        # Parse the values out of the loss summary.
        self.summary_proto.ParseFromString(summary_string)
        result = []
        for val in self.summary_proto.value:
            result.append("({}={})".format(val.tag, val.simple_value))
        return " ".join(result)
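For reference, the per-pixel loss computed by tf.nn.sigmoid_cross_entropy_with_logits in _build_training can be written out in the numerically stable form given in its documentation. A pure-Python sketch (function names are mine):

```python
import math

def sigmoid_xent_with_logits(x, z):
    """Per-pixel sigmoid cross-entropy in the numerically stable form
    documented for tf.nn.sigmoid_cross_entropy_with_logits:
    max(x, 0) - x*z + log(1 + exp(-|x|)), where x is the logit and z the label."""
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))

def naive_sigmoid_xent(x, z):
    # Direct form: -z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x)); unstable for large |x|.
    s = 1.0 / (1.0 + math.exp(-x))
    return -z * math.log(s) - (1.0 - z) * math.log(1.0 - s)

print(abs(sigmoid_xent_with_logits(2.0, 1.0) - naive_sigmoid_xent(2.0, 1.0)) < 1e-12)  # True
```

Note that this loss expects raw logits; that is why the mask labels are normalized to [0, 1] in the input pipeline while the network output is left without a sigmoid.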

6. The Read_TFRecords Class

read_tfrecords.py in the data directory:

import tensorflow as tf


class Read_TFRecords(object):
    def __init__(self, filename, batch_size=64,
        image_h=256, image_w=256, image_c=1, num_threads=8, capacity_factor=3, min_after_dequeue=1000):
        '''
        filename: TFRecords file path.
        num_threads: TFRecords file load thread.
        capacity_factor: capacity.
        '''
        self.filename = filename
        self.batch_size = batch_size
        self.image_h = image_h
        self.image_w = image_w
        self.image_c = image_c
        self.num_threads = num_threads
        self.capacity_factor = capacity_factor
        self.min_after_dequeue = min_after_dequeue

    def read(self):
        # read a TFRecords file, return tf.train.batch/tf.train.shuffle_batch object.

        # Read data from the TFRecords file.
        # Step 1: use tf.train.string_input_producer to build a queue of file names.
        filename_queue = tf.train.string_input_producer([self.filename])

        # Step 2: create a reader with tf.TFRecordReader.
        reader = tf.TFRecordReader()
        # Read from the file-name queue; returns a serialized_example object.
        key, serialized_example = reader.read(filename_queue)
        

        # Step 3: tf.parse_single_example parses an Example protocol buffer into a dict of tensors.
        features = tf.parse_single_example(serialized_example,
            features={
                "image_raw": tf.FixedLenFeature([], tf.string),
                "image_label": tf.FixedLenFeature([], tf.string),
            })

        # Step 4: decode the image tensors and post-process them (resize, normalization, ...).
        # TensorFlow provides two decoding functions, tf.image.decode_jpeg and tf.image.decode_png,
        # for JPEG and PNG images respectively; decoding yields the pixel values.
        image_raw = tf.image.decode_jpeg(features["image_raw"], channels=self.image_c,
            name="decode_image")
        image_label = tf.image.decode_jpeg(features["image_label"], channels=self.image_c,
            name="decode_image")

        # Resize the images to the model's input/output sizes: [1918,1280]==>[512,512] and [1918,1280]==>[324,324]
        if self.image_h is not None and self.image_w is not None:
            image_raw = tf.image.resize_images(image_raw, [self.image_h, self.image_w], 
                method=tf.image.ResizeMethod.BICUBIC)
            image_label = tf.image.resize_images(image_label, [324, 324], 
                method=tf.image.ResizeMethod.BICUBIC)
            # TODO: The shape of image masks. Refer to the Unet in model.py, the output image is
            # 324 * 324 * 1. But is not good.

        # Convert the pixel values to tf.float32 and normalize.
        image_raw = tf.cast(image_raw, tf.float32) / 255.0 # convert to float32
        image_label = tf.cast(image_label, tf.float32) / 255.0 # convert to float32

        # tf.train.batch/tf.train.shuffle_batch object.
        # tf.train.shuffle_batch() uses a queue: it reads a number of tensors into the queue,
        # shuffles them, and each call takes batch_size tensors out as a new batch.
        # Arguments:
        # tensors: the list of tensors to enqueue.
        # batch_size: the number of tensors per batch.
        # capacity: the queue length; a suggested value is
        #     min_after_dequeue + (num_threads + a small safety margin) * batch_size.
        # min_after_dequeue: the minimum number of elements that must remain in the queue after a
        #     dequeue. If it is too small, shuffling is poor even with shuffling enabled; if too
        #     large, it uses more memory.
        # num_threads: the number of enqueueing threads. With more than one thread, they read from
        #     different files and positions simultaneously, mixing the training samples more thoroughly.
        # allow_smaller_final_batch (False): if True, a batch smaller than batch_size is returned
        #     when the queue holds fewer tensors; if False, the leftover tensors are dropped.
        # Using asynchronous queues

        # Step 5: tf.train.shuffle_batch shuffles the training set and returns batch_size samples per call.
        input_data, input_masks = tf.train.shuffle_batch([image_raw, image_label],
            batch_size=self.batch_size,
            capacity=self.min_after_dequeue + self.capacity_factor * self.batch_size,
            min_after_dequeue=self.min_after_dequeue,
            num_threads=self.num_threads,
            name='images')
        
        return input_data, input_masks # return list or dictionary of tensors.
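The queue capacity passed to tf.train.shuffle_batch in read() follows the rule computed inline above. With the class defaults (min_after_dequeue=1000, capacity_factor=3, batch_size=64):

```python
def queue_capacity(min_after_dequeue, capacity_factor, batch_size):
    """Capacity as computed in Read_TFRecords.read():
    min_after_dequeue + capacity_factor * batch_size."""
    return min_after_dequeue + capacity_factor * batch_size

print(queue_capacity(1000, 3, 64))  # 1192
```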

7. Image-Saving Code

utils.py in the model directory:

import numpy as np
import cv2

def save_images(input, output1, output2, input_path, image_path, max_samples=4):
    # Concatenate along the image width (side by side) ==> [batch_size,324,648,1]
    image = np.concatenate([output1, output2], axis=2) # concat 4D array, along width.
    # Number of images stacked vertically = min(max_samples, batch_size)
    if max_samples > int(image.shape[0]):
        max_samples = int(image.shape[0])
    
    image = image[0:max_samples, :, :, :]
    # Stack the selected samples along the height.
    image = np.concatenate([image[i, :, :, :] for i in range(max_samples)], axis=0)
    # concat 3D array, along axis=0, i.e. along height. shape: (max_samples*324, 648, 1).

    # save image.
    # scipy.misc.toimage(), array is 2D(gray, reshape to (H, W)) or 3D(RGB).
    # scipy.misc.toimage(image, cmin=0., cmax=1.).save(image_path) # image_path contain image path and name.
    # clip() limits array elements to [a_min, a_max]: values above a_max become a_max,
    # values below a_min become a_min.
    cv2.imwrite(image_path, np.uint8(image.clip(0., 1.) * 255.))

    # save input
    if input is not None:
        input_data = input[0:max_samples, :, :, :]
        # [max_samples*512, 512, 1]
        input_data = np.concatenate([input_data[i, :, :, :] for i in range(max_samples)], axis=0)
        cv2.imwrite(input_path, np.uint8(input_data.clip(0., 1.) * 255.))
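The tiling logic in save_images can be sanity-checked on dummy arrays: two [4, 324, 324, 1] batches become a single [1296, 648, 1] montage (four rows of side-by-side pairs). A sketch, assuming NumPy is available:

```python
import numpy as np

# Dummy batches with the shapes produced during training: [batch, 324, 324, 1].
output1 = np.zeros((4, 324, 324, 1), dtype=np.float32)
output2 = np.zeros((4, 324, 324, 1), dtype=np.float32)

image = np.concatenate([output1, output2], axis=2)  # side by side: (4, 324, 648, 1)
max_samples = min(4, image.shape[0])
image = np.concatenate([image[i] for i in range(max_samples)], axis=0)  # stack rows along height

print(image.shape)  # (1296, 648, 1)
```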

 