在 TensorFlow 中,如何对稀疏张量(SparseTensor)执行 tf.gather 或 tf.gather_nd?也就是说,如何在不将其转换为密集张量的前提下,从稀疏张量中提取指定的行或特定的元素?
答案 0(得分:0)
下面是一种可能的解决方案。不过它在时间和内存上的开销仍然较大(需要构造一个"非零元素个数 × 所选索引个数"的比较矩阵),因此对于大规模数据可能并不适用:
import tensorflow as tf
def sparse_select_indices(sp_input, indices, axis=0):
    """Select slices of a SparseTensor along one axis without densifying it.

    This is a sparse analogue of ``tf.gather(dense, indices, axis=axis)``,
    with two differences: duplicate entries in ``indices`` are dropped, and
    the selected slices are laid out in ascending index order regardless of
    the order in which they were given.

    Args:
        sp_input: The ``tf.SparseTensor`` to select from.
        indices: 1-D integer tensor (or list) of positions along ``axis`` to
            keep. Duplicates are removed and the positions are sorted.
        axis: Python int, the dimension to select along. Negative values
            count from the last dimension; this requires the rank of
            ``sp_input`` to be known statically.

    Returns:
        A ``tf.SparseTensor`` of the same rank as ``sp_input``. Its size along
        ``axis`` equals the number of unique ``indices``. Position ``i`` along
        ``axis`` holds the input slice at the ``i``-th smallest selected index.

    Raises:
        ValueError: If ``axis`` is negative and the rank of ``sp_input`` is
            not statically known.

    Note:
        Time and memory are proportional to nnz * len(unique(indices)),
        because every stored element is compared against every selected
        index.
    """
    if isinstance(axis, int) and axis < 0:
        # With a negative axis the slices ``[:, :axis]`` / ``[axis + 1:]``
        # below would not bracket the selected column. For example, axis=-1
        # gives ``[:, :-1]`` and ``[:, 0:]``, which yields index rows and a
        # dense shape of the wrong rank. Normalise against the static rank.
        rank = tf.compat.dimension_value(sp_input.indices.shape[1])
        if rank is None:
            raise ValueError(
                "sparse_select_indices: a negative axis requires sp_input "
                "to have a statically known rank.")
        axis += rank

    # Drop duplicate selections (only needed if `indices` may repeat).
    indices, _ = tf.unique(indices)
    n_indices = tf.size(indices)
    # Sort ascending (only needed if `indices` may be unsorted). top_k yields
    # descending order, so reverse it. Ascending order keeps the mapping from
    # old to new coordinates monotonic. So if sp_input is in canonical
    # row-major order, the result is too.
    indices, _ = tf.math.top_k(indices, n_indices)
    indices = tf.reverse(indices, [0])

    # Coordinate of every stored element along the selected axis.
    idx = sp_input.indices[:, axis]
    # eq[i, j] is True iff stored element i lies at selected position
    # indices[j]. This [nnz, n_indices] matrix is what makes the method costly.
    eq = tf.equal(tf.expand_dims(idx, 1), tf.cast(indices, tf.int64))
    # Keep an element if it matches any selected position.
    sel = tf.reduce_any(eq, axis=1)
    values_new = tf.boolean_mask(sp_input.values, sel, axis=0)

    # New coordinate along `axis` = the column j of the match in `eq`.
    # Because `indices` is unique, each row has at most one True entry, so the
    # weighted sum picks out exactly that j.
    n_indices = tf.cast(n_indices, tf.int64)
    idx_new = tf.reduce_sum(tf.cast(eq, tf.int64) * tf.range(n_indices), axis=1)
    idx_new = tf.boolean_mask(idx_new, sel, axis=0)

    # Rebuild the kept index rows, swapping in the new coordinate for `axis`.
    indices_new = tf.boolean_mask(sp_input.indices, sel, axis=0)
    indices_new = tf.concat([indices_new[:, :axis],
                             tf.expand_dims(idx_new, 1),
                             indices_new[:, axis + 1:]], axis=1)
    # The selected dimension shrinks to the number of unique selections.
    shape_new = tf.concat([sp_input.dense_shape[:axis],
                           [n_indices],
                           sp_input.dense_shape[axis + 1:]], axis=0)
    return tf.SparseTensor(indices_new, values_new, shape_new)
以下是使用示例:
import tensorflow as tf
# Build the example in an explicit graph and run it with a v1-compatible
# session. ``tf.Session`` only exists in TensorFlow 1.x, whereas
# ``tf.compat.v1.Session`` inside ``tf.Graph().as_default()`` runs on both
# TensorFlow 1.x (>= 1.13) and 2.x.
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    # Input
    sp1 = tf.SparseTensor([[0, 1], [2, 3], [4, 5]], [10, 20, 30], [6, 7])
    print(sess.run(tf.sparse.to_dense(sp1)))
    # [[ 0 10 0 0 0 0 0]
    # [ 0 0 0 0 0 0 0]
    # [ 0 0 0 20 0 0 0]
    # [ 0 0 0 0 0 0 0]
    # [ 0 0 0 0 0 30 0]
    # [ 0 0 0 0 0 0 0]]
    # Select rows 0, 1, 2
    sp2 = sparse_select_indices(sp1, [0, 1, 2])
    print(sess.run(tf.sparse.to_dense(sp2)))
    # [[ 0 10 0 0 0 0 0]
    # [ 0 0 0 0 0 0 0]
    # [ 0 0 0 20 0 0 0]]
    # Select columns 4, 5
    sp3 = sparse_select_indices(sp1, [4, 5], axis=1)
    print(sess.run(tf.sparse.to_dense(sp3)))
    # [[ 0 0]
    # [ 0 0]
    # [ 0 0]
    # [ 0 0]
    # [ 0 30]
    # [ 0 0]]