在TensorFlow中,如何从python的张量中获取非零值及其索引?

时间:2016-08-30 05:34:07

标签: python tensorflow indices

我想做这样的事情。
我们假设我们有一个张量A.

A = [[1,0],[0,4]]

我希望从中得到非零值及其指数。

Nonzero values: [1,4]  
Nonzero indices: [[0,0],[1,1]]

Numpy也有类似的操作 np.flatnonzero(A)返回在展平的A.中非零的索引 x.ravel()[np.flatnonzero(x)]根据非零指数提取元素 这些操作是a link

如何在Tensorflow中使用python执行上述Numpy操作? (矩阵是否扁平化并不重要。)

1 个答案:

答案 0 :(得分:32)

您可以使用not_equalwhere方法在Tensorflow中获得相同的结果。

zero = tf.constant(0, dtype=tf.float32)
where = tf.not_equal(A, zero)

where是与ATrue False形状相同的张量,在下列情况下

[[True, False],
 [False, True]]

这足以从A中选择零或非零元素。如果要获取索引,可以使用where方法,如下所示:

indices = tf.where(where)

where张量有两个True值,因此indices张量将有两个条目。 where张量的等级为2,因此条目将有两个索引:

[[0, 0],
 [1, 1]]