Tensorflow,如何保留中间节点值并重用它们

时间:2017-07-25 16:09:01

标签: tensorflow

假设我有以下图表 - X3Z是我关心的值 - XY是输入。在每个不同的。迭代,XY的即将到来的值和形状是不同的,所以我认为它们很多placeholder
- 情况是我需要在不同的时间点运行此图表两次以异步方式获取X3Z

+---+     op: +1           op: *3
| X +------------> X_1 +-----------> X3            +---+
+---+               +                +             | Y |
                    |                |             +-+-+ 
                    |             op:add             |
                    |                |               |
                    |                |               |
                    |   op: add      v     op:add    |
                    +------------->     <------------+
                                     Z

在早期时间点,我收到了输入X(说X=7,我不知道Y是什么此时此刻)。我希望看到X3的价值。所以我执行sess.run([X3], {X:7}),然后按预期返回24

在稍后的时间点,我得到另一个输入Y(比如Y=8),这次我只想看一下节点Z 。但重点是我必须执行sess.run([Z], {X:7, Y:8})才能获得结果。

问题是,对于以后的运行,我必须再次提供X以重新计算中间节点X_1X3。它计算流量X--> X_1 --> X3两次会损害效率。

我的想法是X_1X3在早期运行之后将包含值(X_1=8X3=24),直到图表被销毁,然后我可以直接利用而不是重新计算。

有没有办法实现这个目标?

2 个答案:

答案 0 :(得分:2)

以下内容并未完全解决您的问题,但是再次提供X就会失败:

X_temp = tf.Variable(0, dtype=tf.int32)
X = tf.placeholder_with_default(X_temp, shape=())
Y = tf.placeholder(tf.int32, shape=())
X_temp = tf.assign(X_temp, X)

X_1 = X_temp + 1
X3 = X_1 * 3
Z = X_1 + X3 + Y

sess = tf.InteractiveSession()
print(sess.run(X3, {X:7}))
print(sess.run(Z, {Y:8}))

#24
#40

答案 1 :(得分:0)

我建议的一个选择是:

    temp_X1, temp_X3 = sess.run([X_1, X3], feed_dict={X:7})
    sess.run(Z, feed_dict={X_1:temp_X1, X3:temp_X3, Y: 8}

您无需将所有内容存储在tf图表中 有关其他选项,请参阅tensorflow doc(例如使用Saver等)

注意:文档建议您加入placeholder,但加入中级Tensor可以最简单地满足您的要求。