[Object Detection] yolo_weights_convert

PASCAL VOC dataset: https://blog.csdn.net/baidu_27643275/article/details/82754902
yolov1 reading notes: https://blog.csdn.net/baidu_27643275/article/details/82789212
yolov1 source code walkthrough: https://blog.csdn.net/baidu_27643275/article/details/82794559
yolov2 reading notes: https://blog.csdn.net/baidu_27643275/article/details/82859273
yolo_weights_convert: https://blog.csdn.net/baidu_27643275/article/details/83189124
yolov3 source code: https://github.com/1273545169/object-detection/tree/master/yolo_v3



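In TensorFlow, weights are usually loaded with tf.train.Saver.restore() or tf.GraphDef(), from .ckpt and .pb files respectively. The official YOLO weights, however, are distributed as .weights files. This article briefly describes how to convert yolov3.weights into a .ckpt or .pb file.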

The yolo_weights_convert workflow:
1. Build the network
2. Load the weights
3. Save the weights as a .ckpt or .pb file

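1. Build the network

2. Load the weights

To understand the structure of a .weights file, refer to the output of the load_weights1() function.

For a detailed explanation, see: Implementing YOLO v3 in Tensorflow (TF-Slim)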

# The header length (count) matters: it depends on the YOLO version
_ = np.fromfile(fp, dtype=np.int32, count=5)

In yolov3, count=5; in yolov1 and yolov2, count=4.
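
The difference in header length appears to come from how Darknet writes the file: three int32 values (major, minor, revision) followed by a counter of images seen during training. As far as I understand the Darknet loader, that counter is a 64-bit integer when major * 10 + minor >= 2 and a 32-bit integer otherwise, which gives 5 or 4 int32-sized slots. A minimal sketch under that assumption (read_darknet_header is a hypothetical helper, not part of the original code):

```python
import struct


def read_darknet_header(fp):
    """Read the Darknet header and return (major, minor, revision, seen).

    Assumption: little-endian layout with three int32 values followed by
    `seen`, stored as int64 for newer files and as int32 for older ones.
    """
    major, minor, revision = struct.unpack("<3i", fp.read(12))
    if major * 10 + minor >= 2 and major < 1000 and minor < 1000:
        (seen,) = struct.unpack("<q", fp.read(8))  # 5 int32-sized slots in total
    else:
        (seen,) = struct.unpack("<i", fp.read(4))  # 4 int32-sized slots in total
    return major, minor, revision, seen
```

After the call, the file object is positioned at the first float32 value, so the remaining weights can be read in one go, for example with np.fromfile(fp, dtype=np.float32).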

def load_weights(var_list, weights_file):
    """
    Loads and converts pre-trained weights.
    :param var_list: list of network variables.
    :param weights_file: name of the binary file.
    :return: list of assign ops
    """
	
    with open(weights_file, "rb") as fp:
        # The header length (count) matters: it depends on the YOLO version
        _ = np.fromfile(fp, dtype=np.int32, count=5)
        weights = np.fromfile(fp, dtype=np.float32)

    ptr = 0
    i = 0
    assign_ops = []
    while i < len(var_list) - 1:
        # detector/darknet-53/Conv/BatchNorm/
        var1 = var_list[i]
        var2 = var_list[i + 1]
        # do something only if we process conv layer
        if 'Conv' in var1.name.split('/')[-2]:
            # check the type of the next variable; in the file, the batch norm
            # parameters are stored before the conv weights
            if 'BatchNorm' in var2.name.split('/')[-2]:
                # load batch norm params; they correspond to Darknet's
                # l.biases, l.scales, l.rolling_mean, l.rolling_variance
                gamma, beta, mean, var = var_list[i + 1:i + 5]
                batch_norm_vars = [beta, gamma, mean, var]
                for var in batch_norm_vars:
                    shape = var.shape.as_list()
                    num_params = np.prod(shape)
                    var_weights = weights[ptr:ptr + num_params].reshape(shape)
                    ptr += num_params
                    print(var, ptr)
                    assign_ops.append(tf.assign(var, var_weights, validate_shape=True))

                # we move the pointer by 4, because we loaded 4 variables
                i += 4
            elif 'Conv' in var2.name.split('/')[-2]:
                # no batch norm after this conv layer, so only the biases are loaded
                bias = var2
                bias_shape = bias.shape.as_list()
                bias_params = np.prod(bias_shape)
                bias_weights = weights[ptr:ptr + bias_params].reshape(bias_shape)
                ptr += bias_params
                print(bias, ptr)
                assign_ops.append(tf.assign(bias, bias_weights, validate_shape=True))
                # we loaded 1 variable
                i += 1

            # we can load weights of conv layer
            shape = var1.shape.as_list()
            num_params = np.prod(shape)

            var_weights = weights[ptr:ptr + num_params].reshape((shape[3], shape[2], shape[0], shape[1]))
            # Darknet stores kernels as (out, in, height, width); transpose to TensorFlow's (height, width, in, out)
            var_weights = np.transpose(var_weights, (2, 3, 1, 0))
            ptr += num_params
            print(var1, ptr)
            assign_ops.append(tf.assign(var1, var_weights, validate_shape=True))
            i += 1
        # the branches below are not needed for yolov3
        elif 'Local' in var1.name.split('/')[-2]:
            # load biases
            bias = var2
            bias_shape = bias.shape.as_list()
            bias_params = np.prod(bias_shape)
            bias_weights = weights[ptr:ptr + bias_params].reshape(bias_shape)
            ptr += bias_params
            print(bias, ptr)
            assign_ops.append(tf.assign(bias, bias_weights, validate_shape=True))
            i += 1

            # now load the weights of the local layer
            shape = var1.shape.as_list()
            num_params = np.prod(shape)

            var_weights = weights[ptr:ptr + num_params].reshape((shape[3], shape[2], shape[0], shape[1]))
            # Darknet stores kernels as (out, in, height, width); transpose to TensorFlow's (height, width, in, out)
            var_weights = np.transpose(var_weights, (2, 3, 1, 0))
            ptr += num_params
            print(var1, ptr)
            assign_ops.append(tf.assign(var1, var_weights, validate_shape=True))
            i += 1
        elif 'Fc' in var1.name.split('/')[-2]:
            # load biases
            bias = var2
            bias_shape = bias.shape.as_list()
            bias_params = np.prod(bias_shape)
            bias_weights = weights[ptr:ptr + bias_params].reshape(bias_shape)
            ptr += bias_params
            print(bias, ptr)
            assign_ops.append(tf.assign(bias, bias_weights, validate_shape=True))
            i += 1

            # now load the weights of the fully connected layer
            shape = var1.shape.as_list()
            num_params = np.prod(shape)

            var_weights = weights[ptr:ptr + num_params].reshape(shape[1], shape[0])
            # Darknet stores the matrix as (outputs, inputs); transpose to TensorFlow's (inputs, outputs)
            var_weights = np.transpose(var_weights, (1, 0))
            ptr += num_params
            print(var1, ptr)
            assign_ops.append(tf.assign(var1, var_weights, validate_shape=True))
            i += 1
        else:
            # skip variables that match no known layer type, so the loop cannot hang
            i += 1

    return assign_ops
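
The reshape and transpose step for the convolution kernels can be checked in isolation with NumPy. The sketch below assumes, as the code above does, that the file holds each kernel flat in (out_channels, in_channels, height, width) order, while the TensorFlow variable has shape (height, width, in_channels, out_channels). The helper name darknet_to_tf_kernel and the toy shape are hypothetical.

```python
import numpy as np


def darknet_to_tf_kernel(flat, tf_shape):
    """Convert a flat Darknet kernel to TensorFlow's (h, w, in, out) layout."""
    h, w, c_in, c_out = tf_shape
    # interpret the flat buffer in Darknet's (out, in, h, w) order
    kernel = np.asarray(flat, dtype=np.float32).reshape((c_out, c_in, h, w))
    # move the spatial axes first and the output channels last
    return np.transpose(kernel, (2, 3, 1, 0))


# toy kernel: 3x3 window, 2 input channels, 4 output channels
tf_shape = (3, 3, 2, 4)
flat = np.arange(np.prod(tf_shape), dtype=np.float32)
kernel = darknet_to_tf_kernel(flat, tf_shape)
```

Reshaping directly to the TensorFlow shape would produce an array of the right size but with the values in the wrong positions, which is why the intermediate Darknet-ordered reshape is needed.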

3. Save as .ckpt and .pb files

Restore the weights into the model

    with tf.variable_scope('detector'):
        detections = yolo_v3(inputs, len(classes), data_format='NHWC')
        load_ops = load_weights(tf.global_variables(scope='detector'), FLAGS.weights_file)

    sess = tf.Session()
    # run the assign ops returned by load_weights to copy the weights into the graph
    sess.run(load_ops)
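
A mismatch between the network definition and the weights file usually shows up as a reshape error or as floats left over at the end of the file. As a rough sanity check, the number of floats each convolution layer consumes can be computed from its kernel shape, following the same rules as load_weights. This is only a sketch: floats_for_conv, total_floats and the two-layer example are hypothetical and do not describe the real yolov3 layer list.

```python
def floats_for_conv(kernel_shape, batch_norm):
    """Number of float32 values one conv layer consumes from the weights file.

    kernel_shape is TensorFlow's (h, w, in, out). With batch norm the file
    holds beta, gamma, mean and variance (4 * out); otherwise only biases (out).
    """
    h, w, c_in, c_out = kernel_shape
    per_channel = 4 * c_out if batch_norm else c_out
    return per_channel + h * w * c_in * c_out


def total_floats(layers):
    """Sum the float counts over a list of (kernel_shape, batch_norm) pairs."""
    return sum(floats_for_conv(shape, bn) for shape, bn in layers)


# hypothetical example: a 3x3 conv with batch norm, then a 1x1 conv with biases
layers = [((3, 3, 3, 32), True), ((1, 1, 32, 255), False)]
expected = total_floats(layers)
```

If the list covers the whole network, the total should equal the length of the float32 array read after the header.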

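Save the model weights as a .ckpt file

  # Save the weights as a .ckpt file. Reuse the session in which the assign ops
  # were run; a fresh tf.Session() would not contain the loaded values.
  saver = tf.train.Saver()
  saver.save(sess, "data/checkpoint/yolov3.ckpt")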

Save the model weights as a .pb file

def save_weight_to_pbfile(sess, weight_file):
    # Exporting the GraphDef of the current graph is enough to run the computation from the input layer to the output layer
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = tf.graph_util.\
        convert_variables_to_constants(sess, graph_def, output_node_names=["output_boxes","inputs"])
    with tf.gfile.GFile(weight_file, "wb") as f:
        f.write(output_graph_def.SerializeToString())

Full source code: https://github.com/1273545169/object-detection/tree/master/yolo_v3

References:
https://itnext.io/implementing-yolo-v3-in-tensorflow-tf-slim-c3c55ff59dbe
https://github.com/mystic123/tensorflow-yolo-v3
