我想知道是否有办法绕过tf.where()输出的数组形状。例如,这里是我试图运行的代码:
import tensorflow as tf
with tf.Session() as sess:
a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]]).eval()
c = tf.where(tf.equal(a,0)).eval()
c = tf.multiply(100,c).eval()
c = tf.add(a,c)
print(c.eval())
我对输出的预期是:
[[[ 1,100]
[100, 2]
[[100, 3]
[ 4,100]]]
但是,由于tf.where()输出我的代码的方式是4x3张量而不是2x2x2,因此存在错误。是否有另一个命令集我可以使用100'来有效地替换所有零?此方法适用于二维数组。
答案 0 :(得分:0)
您可以创建一个100的张量,然后在where
调用中使用它。根据{{3}},此函数需要两个可选的张量来读取值。
with tf.Session() as sess:
a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]])
h = tf.multiply(tf.ones(a.shape, tf.int32), 100)
c = tf.where(tf.equal(a, 0), h, a)
print(c.eval())