我创建了一个包含一些变量的函数func
。现在,我希望单独使用此函数并通过tf.map_fn
函数使用,我希望为这两种情况保留相同的变量集。但是,显然tf.map_fn
函数将当前变量范围附加map
,因此独立案例的变量范围不再与tf.map_fn
的案例匹配。因此,以下代码抛出错误,因为变量mul1/map/weights
在使用reuse=True
调用之前不存在。
import tensorflow as tf
D = 5
batch_size = 1
def func(x):
W = tf.get_variable(initializer=tf.constant_initializer(1), shape=[D,1], dtype=tf.float32, trainable=True, name="weights")
y = tf.matmul(x, W)
return y
x = tf.placeholder(tf.float32, [batch_size, 5])
x_cat = tf.placeholder(tf.float32, [None, batch_size, 5])
with tf.variable_scope("mul1") as mul1_scope:
y_sum = func(x)
with tf.variable_scope(mul1_scope, reuse=True):
cost = tf.map_fn(lambda x: func(x), x_cat)
这里我只想对mul1/map
范围内的变量运行渐变更新。因此,我可以在每次更新后使用tf.assign
更改mul1
范围内的变量(仅用于前馈步骤)。但这是进行变量共享的一种相当痛苦的方式。所以,我想知道是否有更好的方法来解决这个问题。任何帮助将非常感激 !