如何在tensorflow中实现矩阵掩码操作?

时间:2017-08-20 17:14:28

标签: python numpy image-processing tensorflow

我有一个案例需要在tensorflow中的图像处理应用程序中填充一些漏洞(缺少数据)。 '洞''很容易找到,因为它们是零,好的数据不是零。我想用随机数据填补漏洞。这很容易使用python numpy,但在tensorflow中执行它需要一些工作。我想出了一个解决方案,想看看是否有更好或更有效的方法来做同样的事情。我知道tensorflow还不支持更高级的numpy类型索引,但是有一个函数tf.gather_nd()似乎很有希望。但是,我无法从文档中看出如何将它用于我想做的事情。我会很感激能够改进我所做的事情的答案,或者特别是如果有人可以使用tf.gather_nd()告诉我如何使用它。此外,tf.boolean_mask()不适用于我尝试执行的操作,因为它不允许您将输出用作索引。在python我想要做的事情:

a = np.ones((2,2))
a[0,0]=a[0,1] = 0
mask = a == 0
a[mask] = np.random.random_sample(a.shape)[mask]
print('new a = ', a)

我最终在Tensorflow中做了同样的事情(跳过填充数组步骤)

zeros = tf.zeros(tf.shape(a))  
mask = tf.greater(a,zeros)
mask_n = tf.equal(a,zeros)
mask = tf.cast(mask,tf.float32)
mask_n = tf.cast(mask_n,tf.float32
r = tf.random_uniform(tf.shape(a),minval = 0.0,maxval=1.0,dtype=tf.float32)
r_add = tf.multiply(mask_n,r)
targets = tf.add(tf.multiply(mask,a),r_add)

2 个答案:

答案 0 :(得分:3)

我认为这三行可能会做你想要的。首先,你做一个面具。然后,您创建随机数据。最后,用随机数据填充屏蔽值。

mask = tf.equal(a, 0.0)
r = tf.random_uniform(tf.shape(a), minval = 0.0,maxval=1.0,dtype=tf.float32)
targets = tf.where(mask, r, a)

答案 1 :(得分:1)

您可以使用tf.where来实现相同的目标:

A = tf.Variable(a)
B = tf.where(A==0., tf.random_normal(A.get_shape()), tf.cast(A, tf.float32))