共享变量
import tensorflow.compat.v1 as tf
def main():
"""
共享变量实践
1. 为什么需要共享变量
在某些情况下,一个模型需要使用其他模型创建的变量,
两个模型一起训练(比如对抗网络)
"""
var1 = tf.Variable(1.0, name='firstvar')
var2 = tf.Variable(2.0, name='firstvar')
var3 = tf.Variable(3.0, name='firstvar')
var4 = tf.Variable(4.0)
var5 = tf.Variable(5.0)
print('var1: ', var1.name)
print('var2: ', var2.name)
print('var3: ', var3.name)
print('var4: ', var4.name)
print('var5: ', var5.name)
"""
结果:
var1: firstvar:0
var2: firstvar_1:0
var3: firstvar_2:0
var4: Variable:0
var5: Variable_1:0
分析:
1. `tf.Vairiable()`只能创建新的变量;
2. 如果创建同名变量,tf会默认给新的变量加索引。
所以,虽然定义的变量使用使用了相同的名称,
但其实创建的是三个不同的变量,一个变量的改
变不会影响其他变量的值;
3. 如果tf.Variable()创建变量时不指定名称,则系统
会自动指定名称。
"""
var1 = tf.get_variable(
name='firstvar',
shape=[1],
dtype=tf.float32,
initializer=tf.constant_initializer())
print('var1: ', var1.name)
"""
结果:
var1: firstvar_3:0
分析:
1. 使用`tf.get_variable()`生成的变量是以指定的`name`
属性为 唯一标识,并不是定义的变量名称。使用时一般通过
`name`属性定位到具体位置,并将其共享到其他模型中。
2. 由于变量`firstvar`在前面使用Variable函数生成过一次,
所以系统自动变成了firstvar_3:0
"""
"""
var1 = tf.get_variable(
name='firstvar',
shape=[1],
dtype=tf.float32,
initializer=tf.constant_initializer())
"""
"""
结果:
ValueError: Variable var1 already exists, disallowed.
分析:
程序发生了崩溃,使用`get_variable()`只能定义一次指定名称的变量。
"""
"""
`get_variable()`一般会配合`variable_scope`一起使用,以实现共享变量
"""
with tf.variable_scope('test1'):
var1 = tf.get_variable(
name='firstvar',
shape=[1],
dtype=tf.float32,
initializer=tf.constant_initializer(0.3))
print('var1: ', var1.name)
with tf.variable_scope('test2'):
var2 = tf.get_variable(
name='firstvar',
shape=[1],
dtype=tf.float32,
initializer=tf.constant_initializer(0.3))
print('var2: ', var2.name)
"""
结果:
var1: var1: test1/firstvar:0
var1: var2: test2/firstvar:0
分析:
var1和var2都使用firstvar的名字来定义。从输出结果可以看出,
生成的两个变量var1和var2是不同的,他们作用在不同的scope下,
这就是scope的作用。
"""
with tf.variable_scope('test1'):
with tf.variable_scope('test2'):
var2 = tf.get_variable(
name='firstvar',
shape=[1],
dtype=tf.float32,
initializer=tf.constant_initializer())
print('var2: ', var2.name)
"""
结果:
var1: test1/test2/firstvar:0
分析:
`variable_scope`支持嵌套
"""
"""
使用`variable_scope`的`reuse`参数来实现共享变量功能
reuse=True 表示使用已经定义过的变量,此时`get_variable`
不会再创建新的变量,而是去 `图` 中被get_variable所
创建过的变量中找与`name`相同的变量。
"""
with tf.variable_scope('test1', reuse=True):
var3 = tf.get_variable(
name='firstvar',
shape=[1],
dtype=tf.float32,
initializer=tf.constant_initializer())
print('var3: ', var3.name)
"""
结果:
var3: test1/firstvar:0
分析:
var3与var1输出的名字是一样的,此时就实现了变量共享。
"""
with tf.variable_scope('test1', reuse=True):
with tf.variable_scope('test2'):
var4 = tf.get_variable(
name='firstvar',
shape=[1],
dtype=tf.float32,
initializer=tf.constant_initializer())
print('var4: ', var4.name)
"""
结果:
var4: test1/test2/firstvar:0
分析:
var4的名字与var2的名字相同,这表明var4与var2公用了
一个变量。可以看到,虽然scope test2没有指定reuse=True,
但上层空间test1指定了resuse=True,即variable_scope
的reuse具有继承关系。
"""
"""
`variable_scope`和`get_variable`都有初始化的功能,且作用域
的初始化方法能够被继承
"""
with tf.variable_scope('test1', reuse=tf.AUTO_REUSE, initializer=tf.constant_initializer(0.4)):
var1 = tf.get_variable('secondvar', shape=[2], dtype=tf.float32)
with tf.variable_scope('test2', reuse=tf.AUTO_REUSE):
var2 = tf.get_variable('secodvar', shape=[2], dtype=tf.float32)
var3 = tf.get_variable('var3', shape=[2], dtype=tf.float32, initializer=tf.constant_initializer(0.3))
print('var1: ', var1)
print('var2: ', var2)
print('var3: ', var3)
"""
结果:
var1:
var2:
var3:
分析:
1. 在`variable_scope`中使用`tf.AUTO_REUSE`为reuse属性赋值
因为之前已经使用`test1/firstvar:0`和`test1/test2/firstvar:0`
定义过变量。`tf.AUTO_REUSE`可以实现第一次调用`variable_scope`
时传入的reuse值为False,再次调用`variable_scope`时,传入的
reuse值自动变为True。
2. 将test1作用域初始化为4.0,var1没有初始化时,初始值继承test1作用域的初始化值。
3. test2作用域没有定义初始化方法,var2的初始值也为4.0,表明test2作用域的初始值
继承了test1作用域的初始值。
4. test1/test2作用域下的var3定义了初始化方法,则不再继承作用域的初始化方法。
"""
if __name__ == '__main__':
main()