稀疏张量上的行或元素选择

时间:2018-11-29 14:57:47

标签: python tensorflow sparse-matrix

在 TensorFlow 中,如何对稀疏张量执行 tf.gather 或 tf.gather_nd 这样的操作?如何在不将其转换为稠密张量的情况下,从稀疏张量中选取特定的行或特定的元素?

1 个答案:

答案 0 :(得分:0)

下面是一种可能的解决方案,不过它在时间和内存上的开销仍然很大,因此对于大规模的使用场景可能并不可行:

import tensorflow as tf

def sparse_select_indices(sp_input, indices, axis=0):
    """Select slices of a sparse tensor along one axis without densifying it.

    Keeps only the entries of ``sp_input`` whose coordinate along ``axis``
    appears in ``indices``, and renumbers that coordinate so the result is
    compact along ``axis``.

    Unlike ``tf.gather``, the requested ``indices`` are deduplicated and
    sorted in ascending order first, so the output slices always appear in
    ascending index order regardless of the order given in ``indices``.
    Requested positions that hold no stored entry (including positions past
    the end of the axis) simply yield empty slices.

    The matching step builds a dense boolean matrix of shape
    ``[number of stored entries, number of unique indices]``, so the cost in
    time and memory grows with the product of the two.

    Args:
        sp_input: The ``tf.SparseTensor`` to select from.
        indices: 1-D integer sequence or tensor with the positions to keep
            along ``axis``.
        axis: Dimension to select along. A negative Python int counts from
            the last dimension.

    Returns:
        A ``tf.SparseTensor`` of the same rank as ``sp_input`` whose size
        along ``axis`` equals the number of unique values in ``indices``.

    Raises:
        ValueError: If ``axis`` is negative and the rank of ``sp_input`` is
            not statically known.
    """
    # Normalise a negative axis. Without this, ``axis + 1`` in the slices
    # below wraps around (e.g. ``[:, 0:]`` for axis=-1) and silently
    # re-appends every column instead of only the trailing ones.
    if isinstance(axis, int) and axis < 0:
        rank = sp_input.dense_shape.shape.as_list()[0]
        if rank is None:
            raise ValueError(
                'A negative axis requires sp_input to have a statically '
                'known rank.')
        axis += rank
    # Only necessary if indices may have non-unique elements
    indices, _ = tf.unique(indices)
    n_indices = tf.size(indices)
    # Only necessary if indices may not be sorted. top_k returns the values
    # in descending order, hence the reverse. Sorting keeps the renumbered
    # coordinate monotonic in the original one, so the stored entries stay
    # in their original relative order.
    indices, _ = tf.math.top_k(indices, n_indices)
    indices = tf.reverse(indices, [0])
    # Get indices for the axis
    idx = sp_input.indices[:, axis]
    # Find where indices match the selection; broadcasting yields one row per
    # stored entry and one column per selected index
    eq = tf.equal(tf.expand_dims(idx, 1), tf.cast(indices, tf.int64))
    # Mask for selected values
    sel = tf.reduce_any(eq, axis=1)
    # Selected values
    values_new = tf.boolean_mask(sp_input.values, sel, axis=0)
    # New index value for selected elements. The selection is unique, so each
    # row of eq has at most one True and the weighted sum is exactly the
    # position of that match within the sorted selection.
    n_indices = tf.cast(n_indices, tf.int64)
    idx_new = tf.reduce_sum(tf.cast(eq, tf.int64) * tf.range(n_indices), axis=1)
    idx_new = tf.boolean_mask(idx_new, sel, axis=0)
    # New full indices tensor: same coordinates, with the column for `axis`
    # replaced by the renumbered one
    indices_new = tf.boolean_mask(sp_input.indices, sel, axis=0)
    indices_new = tf.concat([indices_new[:, :axis],
                             tf.expand_dims(idx_new, 1),
                             indices_new[:, axis + 1:]], axis=1)
    # New shape: only the size along `axis` changes
    shape_new = tf.concat([sp_input.dense_shape[:axis],
                           [n_indices],
                           sp_input.dense_shape[axis + 1:]], axis=0)
    return tf.SparseTensor(indices_new, values_new, shape_new)

以下是使用示例:

import tensorflow as tf

with tf.Session() as sess:

    def show(sparse):
        # Densify only for display, then evaluate and print the result.
        print(sess.run(tf.sparse.to_dense(sparse)))

    # Input: a 6x7 sparse matrix with three stored values
    source = tf.SparseTensor(indices=[[0, 1], [2, 3], [4, 5]],
                             values=[10, 20, 30],
                             dense_shape=[6, 7])
    show(source)
    # [[ 0 10  0  0  0  0  0]
    #  [ 0  0  0  0  0  0  0]
    #  [ 0  0  0 20  0  0  0]
    #  [ 0  0  0  0  0  0  0]
    #  [ 0  0  0  0  0 30  0]
    #  [ 0  0  0  0  0  0  0]]

    # Keep rows 0, 1, 2 (default axis=0)
    first_rows = sparse_select_indices(source, [0, 1, 2])
    show(first_rows)
    # [[ 0 10  0  0  0  0  0]
    #  [ 0  0  0  0  0  0  0]
    #  [ 0  0  0 20  0  0  0]]

    # Keep columns 4, 5
    last_columns = sparse_select_indices(source, [4, 5], axis=1)
    show(last_columns)
    # [[ 0  0]
    #  [ 0  0]
    #  [ 0  0]
    #  [ 0  0]
    #  [ 0 30]
    #  [ 0  0]]