如何从tf.SparseTensor中删除显式0?

时间:2018-05-06 04:03:15

标签: python tensorflow sparse-matrix

在我的模型训练的每个时代中,tf.SparseTensor更改其值以使其具有更明确的零。删除这样的显式零将使显式边的数量变小,从而使整个计算更快。

所以,我需要一种方法从tf.SparseTensor中删除明确的零,使其更“苗条”。有谁知道这样做的方法?

1 个答案:

答案 0 :(得分:0)

您可以使用tf.sparse_retain() op:

解决此问题
st = ...  # A `tf.SparseTensor` object.

# Compute a vector of booleans indicating which values of `st` should be dropped
# (if False) or retained (if True)
is_nonzero = tf.not_equal(st.values, 0)

# `tf.sparse_retain()` computes a new `tf.SparseTensor` with the specified values
# retained in the output.
st_without_zeros = tf.sparse_retain(st, is_nonzero)