过滤出张量中的非零值

时间:2017-02-12 22:29:25

标签: python python-2.7 python-3.x numpy tensorflow

假设我有一个数组:'\n',我想过滤掉input = np.array([[1,0,3,5,0,8,6]])

我知道您可以将[1,3,5,8,6]与条件一起使用,但返回的值仍然为0。以下代码段的输出为tf.where。我也不明白为什么[[[1 0 3 5 0 8 6]]]同时需要tf.wherex

无论如何,我可以摆脱由此产生的张量中的0?

y

2 个答案:

答案 0 :(得分:3)

首先创建一个布尔掩码,以确定条件的真实位置;然后将蒙版应用于张量,如下所示。你可以使用tf.where来索引 - 但是它使用x& y返回一个与输入相同等级的张量,所以没有进一步的工作,你可以达到的最好就是[[[1 -1 3 5 - 1 8 6]]]用你要识别的东西改变-1以便以后删除。只使用where(没有x& y)将为您提供条件为真的所有值的索引,因此如果您喜欢这样,可以使用索引创建解决方案。我的建议如下是最清晰的。

import numpy as np
import tensorflow as tf
input = np.array([[1,0,3,5,0,8,6]])
X = tf.placeholder(tf.int32,[None,7])
zeros = tf.cast(tf.zeros_like(X),dtype=tf.bool)
ones = tf.cast(tf.ones_like(X),dtype=tf.bool)
loc = tf.where(input!=0,ones,zeros)
result=tf.boolean_mask(input,loc)
with tf.Session() as sess:
 out = sess.run([result],feed_dict={X:input})
 print (np.array(out))

答案 1 :(得分:0)

将数字强制转换为布尔型会将零标识为False。然后,您可以像往常一样遮罩。示例:

x = [1,0,2]
mask = tf.cast(x, dtype=tf.bool)  # [True, False, True]
nonzero_x = tf.boolean_mask(x, mask)  # [1, 2]