解决方法到tf.where()形状输出

时间:2017-07-11 22:00:13

标签: python tensorflow

我想知道是否有办法绕过tf.where()输出的数组形状。例如,这里是我试图运行的代码:

import tensorflow as tf

with tf.Session() as sess:
    a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]]).eval()
    c = tf.where(tf.equal(a,0)).eval()
    c = tf.multiply(100,c).eval()
    c = tf.add(a,c)

    print(c.eval())

我对输出的预期是:

[[[  1,100]
  [100,  2]

 [[100,  3]
  [  4,100]]]

但是,由于tf.where()输出我的代码的方式是4x3张量而不是2x2x2,因此存在错误。是否有另一个命令集我可以使用100'来有效地替换所有零?此方法适用于二维数组。

1 个答案:

答案 0 :(得分:0)

您可以创建一个100的张量,然后在where调用中使用它。根据{{​​3}},此函数需要两个可选的张量来读取值。

with tf.Session() as sess:
    a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]])
    h = tf.multiply(tf.ones(a.shape, tf.int32), 100)
    c = tf.where(tf.equal(a, 0), h, a)

    print(c.eval())