Tensorflow:tf.map_fn的可变重用

时间:2016-07-20 01:16:43

标签: python tensorflow deep-learning

我创建了一个包含一些变量的函数func。现在,我希望单独使用此函数并通过tf.map_fn函数使用,我希望为这两种情况保留相同的变量集。但是,显然tf.map_fn函数将当前变量范围附加map,因此独立案例的变量范围不再与tf.map_fn的案例匹配。因此,以下代码抛出错误,因为变量mul1/map/weights在使用reuse=True调用之前不存在。

    import tensorflow as tf
    D = 5
    batch_size = 1
    def func(x):
        W = tf.get_variable(initializer=tf.constant_initializer(1), shape=[D,1], dtype=tf.float32, trainable=True, name="weights")
        y = tf.matmul(x, W)
        return y

    x = tf.placeholder(tf.float32, [batch_size, 5])
    x_cat = tf.placeholder(tf.float32, [None, batch_size, 5])
    with tf.variable_scope("mul1") as mul1_scope:
        y_sum = func(x)
    with tf.variable_scope(mul1_scope, reuse=True):
        cost = tf.map_fn(lambda x: func(x), x_cat)

这里我只想对mul1/map范围内的变量运行渐变更新。因此,我可以在每次更新后使用tf.assign更改mul1范围内的变量(仅用于前馈步骤)。但这是进行变量共享的一种相当痛苦的方式。所以,我想知道是否有更好的方法来解决这个问题。任何帮助将非常感激 !

0 个答案:

没有答案