向TensorFlow tf.case馈送数组(形状为1级的形状)

时间:2019-06-17 16:02:23

标签: python tensorflow

以下来自tf.case文档的示例:

def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
            default=f3, exclusive=True)

我想做同样的事情,但是允许使用feed_dict作为输入,如以下片段所示:

x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])
z = tf.placeholder(tf.float32, shape=[None])
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
            default=f3, exclusive=True)
print(sess.run(r, feed_dict={x: [0, 1, 2, 3], y: [1, 1, 1, 1], z: [2, 2, 2, 2]}))
# result should be [17, -1, -1, 23]

因此,基本上,我想提供三个相等长度的int数组,并接收包含17、23或-1的int值数组。不幸的是,上面的代码给出了错误:

  

ValueError:形状必须为0,但对于“ case / cond / Switch”(操作数为“ Switch”),输入形状为[?],[?]则为1级。

我知道,tf.case需要布尔型标量张量输入值,但是有什么方法可以实现我想要的?我也尝试了tf.cond,但没有成功。

1 个答案:

答案 0 :(得分:0)

为此使用tf.where,例如这样的{broadcasting support for tf.where seems to be on its way,但据我所知还没有,所以您必须确保所有参数的大小都相同,且向量为或tf.filltf.tile ...)。

The property 'TranslationModelID' was not found in type 'PassingAttachedPropertyInControlTemplater.Translation'. [Line: 17 Position: 40]