如何从Sparse Tensor中仅获取非零值

时间:2017-02-09 11:07:56

标签: tensorflow

利用TensorFlow的HashTable查找实现,我将SparseTensor恢复为提供的默认值。我想清除它并获得没有默认值的最终SparseTensor。

如何清除该默认值?我不介意默认值是为了实现这一目标。 0很好,因此是-1。

1 个答案:

答案 0 :(得分:0)

tf.sparse_retain应该有效:

def sparse_remove(sparse_tensor, remove_value=0.):
  return tf.sparse_retain(sparse_tensor, tf.not_equal(a.values, remove_value))

举个例子:

import tensorflow as tf

a = tf.SparseTensor(indices=[[1, 2], [2, 2]], values=[0., 1.], shape=[3, 3])
with tf.Session() as session:
  print(session.run([a, sparse_remove(a)]))

打印(我稍微重新格式化了):

[SparseTensorValue(indices=array([[1, 2], [2, 2]]), values=array([ 0.,  1.], dtype=float32), shape=array([3, 3])), 
 SparseTensorValue(indices=array([[2, 2]]), values=array([ 1.], dtype=float32), shape=array([3, 3]))]