请考虑以下包含张量流tf.cond()
的代码段。
import tensorflow as tf
import numpy as np
bb = tf.placeholder(tf.bool)
xx = tf.placeholder(tf.float32, name='xx')
yy = tf.placeholder(tf.float32, name='yy')
zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)
with tf.Session() as sess:
dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
print(sess.run(zz, feed_dict=dict1)) # works fine without errors
dict2 = {bb:False, yy:np.array([1., 3, 4])}
print(sess.run(zz, feed_dict=dict2)) # get an InvalidArgumentError asking to
# provide an input for xx
在这两种情况下,bb
都是False
,zz
的评估理论上不依赖xx
,但仍然需要输入xx
张量流。即使它可以作为虚拟数组提供,但它必须与yy
的形状匹配,并且不像dict2
那样干净。
有人可以建议如何评估zz
(使用tf.cond()
或任何其他方法)而不提供xx
的值吗?
答案 0 :(得分:8)
您可以将xx
定义为tf.Variable
,并为其指定一个默认值(只要xx
没有提供其他值,就会使用该值)。有几点需要注意:
xx
不是占位符,但您仍然可以通过feed_dict
将值添加到其中来对待它。validate_shape=False
,以便将任何形状投放到xx
。trainable=False
以便xx
未优化(否则,优化程序可能会将其默认值更改为Nan
,这可能会导致问题。)xx
初始化tf.global_variables_initializer()
的值。以下是代码:
import tensorflow as tf
import numpy as np
bb = tf.placeholder(tf.bool)
xx = tf.Variable(initial_value=0.0,validate_shape=False,trainable=False,name='xx')
yy = tf.placeholder(tf.float32, name='yy')
zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
print(sess.run(zz, feed_dict=dict1))
dict2 = {bb:False, yy:np.array([1., 3, 4])}
print(sess.run(zz, feed_dict=dict2))