我可以使用tf.Variable运行我的tensorflow代码但是tf.get_variable更有效率。上述错误由以下代码产生:
w = tf.get_variable(name='weights',
shape=filter_shape,
initializer=tf.random_normal_initializer(0., 0.01))
b = tf.get_variable(name='biases',
shape=filter_shape[-1],
initializer=tf.constant_initializer(0.))
我无法理解原因。有什么想法吗?
答案 0 :(得分:4)
tf.get_variable
使用变量范围来启用变量共享。以下是对how to share variables的解释。
具体来说,我倾向于使用以下框架将变量初始化与获取变量分开。
def initialize_variables(scope_name, shape):
'''initialize variables within variable scope_name.'''
with tf.variable_scope(scope_name, reuse=None) as scope:
w = tf.get_variable("weight", shape, initializer = random_normal_initializer(0., 0.01)))
b = tf.get_variable("biase", shape[-1], initializer = tf.constant_initializer(0.0))
scope.reuse_variables()
def fetch_variables(scope_name):
'''fetch variables within variable scope_name'''
with tf.variable_scope(scope_name, reuse=True):
w = tf.get_variable("weight")
b = tf.get_variable("biase")
return w, b
请注意,reuse=None
功能中的initialize_variables
设置会根据给定的w
设置重新b
和initializer
。在fetch_variables
中,reuse=True
设置启用了变量共享。