Tensorflow使op赋予显式依赖性以计算张量

时间:2017-10-25 10:12:49

标签: python tensorflow

我希望能够在每次运行另一个张量时隐式运行assign Op,这取决于在tf.Variable操作期间更改的assign。我不想每一步都手动运行assign Op。我尝试了两种不同的方法。这是一个简单的示例说明:

target_prob     = tf.placeholder(dtype=tf.float32, shape=[None, 2])
target_var      = tf.Variable(0, trainable=False, dtype=tf.float32)
init_target_var = tf.assign(target_var, tf.zeros_like(target_prob),
                            validate_shape=False)

# First approach
with tf.control_dependencies([init_target_var]):
  result = target_prob + target_var

# Second approach
# [target_var] = tf.tuple([target_var], control_inputs=[init_target_var])
# result = target_prob + target_var

sess = tf.Session()
sess.run(tf.global_variables_initializer())
res1 = sess.run(result, feed_dict={target_prob: np.ones([10, 2], dtype=np.float32)})
res2 = sess.run(result, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})

在计算InvalidArgumentError (see above for traceback): Incompatible shapes: [12,2] vs. [10,2]时,两者都失败并显示错误res2。如果我这样做,这一切都有效:

res1 = sess.run(result, feed_dict={target_prob: np.ones([10, 2], dtype=np.float32)})
sess.run(init_target_var, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})
res2 = sess.run(result, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})

但同样,明确地运行init_target_var正是我想要避免的。

P.S。以上只是一个简单的例子。我的最终目标是使用来自tf.scatter_add的结果张量,不幸的是需要一个可变张量作为输入。

1 个答案:

答案 0 :(得分:0)

对于遇到这种情况的人来说,我实际上在计算result时使用了错误的张量。正确的代码是:

import tensorflow as tf
import numpy as np

target_prob         = tf.placeholder(dtype=tf.float32, shape=[None, 2])
tmp_var             = tf.Variable(0, trainable=False, dtype=tf.float32, validate_shape=False)
target_var          = tf.assign(tmp_var, tf.zeros_like(target_prob), validate_shape=False)

with tf.control_dependencies([target_var]):
  result = target_prob + target_var

sess = tf.Session()
sess.run(tf.global_variables_initializer())

res1 = sess.run(result, feed_dict={target_prob: np.ones([10, 2], dtype=np.float32)})
res2 = sess.run(result, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})