Tensorflow中查看训练好的模型网络结构和变量名称

在复用已训练好的模型时,进行需要查看模型网络结构和对应的变量名称。

查看Ops的信息

# after load model from file
...
# print out operations information
ops = tf.get_default_graph().get_operations()
for op in ops:
    print(op.name) # print operation name
    print('> ', op.values()) # print a list of tensors it produces

通过查看ops的输出,可以找到需要控制的tensor,以insightface为例,我们想自己控制embedding。
下面是dump出来的结构,最下面就是embedding

image_input
>  (,)
W_conv0
>  (,)
W_conv0/read
>  (,)
conv0
>  (,)
training_mode
>  ( dtype=bool>,)
BatchNorm/gamma
>  (,)
BatchNorm/gamma/read
>  (,)
...
...
embedding/Maximum/y
>  (,)
embedding/Maximum
>  (,)
embedding/Rsqrt
>  (,)
embedding
>  (,)

# 使用get_tensor_by_name,获取对应的tensor
embedding = tf.get_default_graph().get_tensor_by_name('embedding:0')

embedding具体依赖哪些tensor,需要外部出入数据,可以通过直接run,就可以看出来了。

查看vars的信息

使用get_collection可以获取相应的信息,通过设置GraphKeys即可,还有一些其他的选项,如:TRAINABLE_VARIABLES

# print out variables
for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(v)

下面的tf.global_variables或者tf.trainable_variables就是对上面的封装。

# print out all of variables
for v in tf.global_variables():
    print(v)

# print out all of trainable variables
for v in tf.trainable_variables():
    print(v)

你可能感兴趣的:(Tensorflow)