避免在tensorflow中重复图形(LSTM模型)

时间:2018-03-05 15:50:14

标签: python tensorflow while-loop lstm recurrent-neural-network

我有以下简化代码(实际上,展开的LSTM模型):

def func(a, b):
    with tf.variable_scope('name'):
        res = tf.add(a, b)
    print(res.name)
    return res

func(tf.constant(10), tf.constant(20))

每当我运行最后一行时,它似乎都会改变图形。但我不想让图表发生变化。实际上我的代码是不同的,是一个神经网络模型,但它太大了,所以我添加了上面的代码。我想调用func而不更改模型图,但它会发生变化。我在TensorFlow中读到了变量范围,但似乎我根本不理解它。

1 个答案:

答案 0 :(得分:3)

您应该查看tf.nn.dynamic_rnn的源代码,特别是python/ops/rnn.py处的_dynamic_rnn_loop函数 - 它解决了同样的问题。为了不炸毁图表,它使用tf.while_loop重新使用相同的图表操作来获取新数据。但是这种方法增加了一些限制,即在循环中传递的张量的形状必须是不变的。请参阅tf.while_loop文档中的示例:

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])