假设我们定义了一个keras模型,并且使用它的save_weights函数保存了一些参数.现在我们只定义这个模型的一部分,并且使用load_weights去加载我们保存的这个完整的模型,会发生什么?
首先看源代码,load_weights实际上是调用了tensorflow_core/python/keras/engine/network.py文件中Network类的load_weights函数,而在这个函数中,分别对save_format为tf和h5两种类型的文件做了不同的处理,我们这里不是h5文件,那么显然就应该是tf文件.源代码就不贴了,根据源代码,我猜测读取tf文件的函数是tensorflow_core/python/training/tracking/util.py中的TrackableSaver类的restore函数.我在这个函数的doc中发现了这样一句话.
If the checkpoint has not been consumed completely, then the list of restore ops will grow as more objects are added to the dependency graph.
这句话好像在说如果我们的checkpoing文件中的内容没有被全部加载,那么剩下没有加载的变量将会在我们使用这个图的时候动态的加载.
那么来做实验吧.首先我们定义了一个图,就是prosody-tacotron中的global style token(gst)层.这个图有两种模式,'reference'和'weight',在'reference'模式下模型参数会被完全定义,而在'weight'模式下,模型只会定义gst嵌入,而不会定义多头注意力层和卷积层的变量.下面是效果
hp.gst_mode = 'reference'
model = get_gst()
for i in model.trainable_variables:
print(i.name)
"""
输出:
gst_tokens:0
gst/multihead_attention/attention_v:0
gst/multihead_attention/attention_g:0
gst/multihead_attention/attention_b:0
gst/multihead_attention/conv1d/kernel:0
gst/multihead_attention/conv1d/bias:0
gst/multihead_attention/conv1d_1/kernel:0
gst/multihead_attention/conv1d_1/bias:0
"""
hp.gst_mode = 'weight'
model = get_gst()
for i in model.trainable_variables:
print(i.name)
"""
输出:
gst_tokens:0
"""
我们在'reference'模式下保存模型,并且在'weight'模式下读取模型,然后比较它们是否相同
hp.gst_mode = 'reference'
model1 = get_gst()
model1.save_weights('./checkpoint')
hp.gst_mode = 'weight'
model2 = get_gst()
model2.load_weights('./checkpoint')
for i in range(len(model2.trainable_variables)):
v1 = model1.trainable_variables[i]
v2 = model2.trainable_variables[i]
print(v1.name, end=": ")
print(tf.reduce_all(tf.equal(v1, v2)))
"""
输出:
gst_tokens:0: tf.Tensor(True, shape=(), dtype=bool)
"""
可以看到,只定义了部分模型参数的model2,它的参数与完整模型model1的对应部分相同.然后我们使用model2在'reference'模式下进行计算,再来比较二者的参数
hp.gst_mode = 'reference'
model2(np.zeros([32, 128], dtype=np.float32))
for i in range(len(model2.trainable_variables)):
v1 = model1.trainable_variables[i]
v2 = model2.trainable_variables[i]
print(v1.name, end=": ")
print(tf.reduce_all(tf.equal(v1, v2)))
"""
输出:
gst_tokens:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/attention_v:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/attention_g:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/attention_b:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/conv1d/kernel:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/conv1d/bias:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/conv1d_1/kernel:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/conv1d_1/bias:0: tf.Tensor(True, shape=(), dtype=bool)
"""
我们可以看到model2被完全定义,且model1和model2的所有参数均相同.作为对比,如果model2不加载model1的参数,我们得到如下输出.
hp.gst_mode = 'reference'
model1 = get_gst()
model1.save_weights('./checkpoint')
hp.gst_mode = 'weight'
model2 = get_gst()
for i in range(len(model2.trainable_variables)):
v1 = model1.trainable_variables[i]
v2 = model2.trainable_variables[i]
print(v1.name, end=": ")
print(tf.reduce_all(tf.equal(v1, v2)))
"""
输出:
gst_tokens:0: tf.Tensor(False, shape=(), dtype=bool)
"""
hp.gst_mode = 'reference'
model2(np.zeros([32, 128], dtype=np.float32))
for i in range(len(model2.trainable_variables)):
v1 = model1.trainable_variables[i]
v2 = model2.trainable_variables[i]
print(v1.name, end=": ")
print(tf.reduce_all(tf.equal(v1, v2)))
"""
输出:
gst_tokens:0: tf.Tensor(False, shape=(), dtype=bool)
gst/multihead_attention/attention_v:0: tf.Tensor(False, shape=(), dtype=bool)
gst/multihead_attention/attention_g:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/attention_b:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/conv1d/kernel:0: tf.Tensor(False, shape=(), dtype=bool)
gst/multihead_attention/conv1d/bias:0: tf.Tensor(True, shape=(), dtype=bool)
gst/multihead_attention/conv1d_1/kernel:0: tf.Tensor(False, shape=(), dtype=bool)
gst/multihead_attention/conv1d_1/bias:0: tf.Tensor(True, shape=(), dtype=bool)
"""
在model2不加载model1保存的模型的情况下,model1和model2的参数不同,注意结果中相同的部分是因为我们将这些参数初始化为了一个常数.
那么如果我们换一个方式,让model1去加载model2保存的模型,会发生什么呢?下面是实验结果:
hp.gst_mode = 'weight'
model2 = get_gst()
model2.save_weights('./checkpoint')
hp.gst_mode = 'reference'
model1 = get_gst()
model1.load_weights('./checkpoint')
for i in range(len(model2.trainable_variables)):
v1 = model1.trainable_variables[i]
v2 = model2.trainable_variables[i]
print(v1.name, end=": ")
print(tf.reduce_all(tf.equal(v1, v2)))
"""
输出:
gst_tokens:0: tf.Tensor(True, shape=(), dtype=bool)
"""
hp.gst_mode = 'reference'
model2(np.zeros([32, 128], dtype=np.float32))
for i in range(len(model2.trainable_variables)):
v1 = model1.trainable_variables[i]
v2 = model2.trainable_variables[i]
print(v1.name, end=": ")
print(tf.reduce_all(tf.equal(v1, v2)))
"""
输出:
gst_tokens:0: tf.Tensor(True, shape=(), dtype=bool)
gst_1/multihead_attention/attention_v:0: tf.Tensor(False, shape=(), dtype=bool)
gst_1/multihead_attention/attention_g:0: tf.Tensor(True, shape=(), dtype=bool)
gst_1/multihead_attention/attention_b:0: tf.Tensor(True, shape=(), dtype=bool)
gst_1/multihead_attention/conv1d_2/kernel:0: tf.Tensor(False, shape=(), dtype=bool)
gst_1/multihead_attention/conv1d_2/bias:0: tf.Tensor(True, shape=(), dtype=bool)
gst_1/multihead_attention/conv1d_3/kernel:0: tf.Tensor(False, shape=(), dtype=bool)
gst_1/multihead_attention/conv1d_3/bias:0: tf.Tensor(True, shape=(), dtype=bool)
"""
可以看到,model1可以加载model2保存的模型,但由于model2只包含一部分的模型,所以model1也就只能加载这一部分,剩下的部分则会随机初始化,所以在model2被完全定义之后,对于剩余的参数,model1和model2则不同.