Converting the output of tf.nn.top_k to a sparse matrix

Time: 2017-04-05 07:14:29

Tags: python matrix indexing tensorflow sparse-matrix

As the title says, I am trying to extract the n highest elements of each row of a matrix in TensorFlow and store the result in a sparse Tensor.

I have been able to extract the indices and values with tf.nn.top_k, but the indices do not follow the convention that tf.SparseTensor requires.

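Specifically, tf.nn.top_k returns a matrix of column indices with the same shape as the resulting value matrix (rows x n), whereas tf.SparseTensor wants a (# non-zero x 2) matrix, with one row per non-zero element and two columns holding the row and column indices.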

The values probably have a similar problem: a flat list of the non-zero elements is needed rather than a matrix of values.

How can I convert quickly between these two index notations?

1 answer:

Answer 0 (score: 2)

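This can be done with a bit of modular arithmetic. Here is an example that works on matrices, although it would be possible to loop over more axes.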

import tensorflow as tf

def slices_to_dims(slice_indices):
  """
  Args:
    slice_indices: An [N, k] Tensor mapping to column indices.
  Returns:
    An index Tensor with shape [N * k, 2], corresponding to indices suitable for
    passing to SparseTensor.
  """
  slice_indices = tf.cast(slice_indices, tf.int64)
  num_rows = tf.shape(slice_indices, out_type=tf.int64)[0]
  row_range = tf.range(num_rows)
  # Encode each (row, column) pair as one number: column * num_rows + row.
  item_numbers = slice_indices * num_rows + tf.expand_dims(row_range, axis=1)
  item_numbers_flat = tf.reshape(item_numbers, [-1])
  # Decode: the remainder is the row index, the quotient is the column index.
  return tf.stack([item_numbers_flat % num_rows,
                   item_numbers_flat // num_rows], axis=1)
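
To see why the modular arithmetic works, the same steps can be traced in plain Python without TensorFlow. This is only an illustrative sketch: the sample column indices are made up, and the list comprehensions stand in for the broadcasting and tf.reshape calls in the function.

```python
# Plain-Python trace of the arithmetic in slices_to_dims (no TensorFlow needed).
# The sample column indices below are made up for illustration.
slice_indices = [[3, 6],   # row 0: columns of its top-2 entries
                 [0, 5],   # row 1
                 [4, 3]]   # row 2
num_rows = len(slice_indices)

# Encode each (row, col) pair as a single number: col * num_rows + row.
item_numbers = [[col * num_rows + row for col in cols]
                for row, cols in enumerate(slice_indices)]

# Flatten in row-major order, as tf.reshape(..., [-1]) does.
item_numbers_flat = [n for row in item_numbers for n in row]

# Decode: the remainder recovers the row, the quotient recovers the column.
sparse_indices = [[n % num_rows, n // num_rows] for n in item_numbers_flat]
print(sparse_indices)
# [[0, 3], [0, 6], [1, 0], [1, 5], [2, 4], [2, 3]]
```

Row 2 comes out with column 4 before column 3, because top-k results are ordered by value rather than by column. Indices in that order are not in the canonical row-major order that sparse ops expect, so they need to be reordered before use.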

Example usage:

dense_shape = [5, 7]
dense_matrix = tf.random_normal(shape=dense_shape)
top_values, top_indices = tf.nn.top_k(dense_matrix, k=2)
sparse_indices = slices_to_dims(top_indices)
sparse_tensor = tf.sparse_reorder(tf.SparseTensor(
    indices=sparse_indices, 
    values=tf.reshape(top_values, [-1]),
    dense_shape=dense_shape))
densified_top = tf.sparse_tensor_to_dense(sparse_tensor)
with tf.Session() as session:
  sparse_top, dense_original, dense_selected = session.run(
      [sparse_tensor, dense_matrix, densified_top])
  print(dense_original)
  print(dense_selected)
  print(sparse_top)

Prints:

[[ 1.44056129 -1.01790774 -0.2795608   2.34854746 -2.27528405 -0.62035948
   3.36598897]
 [ 0.7114948  -0.42564821 -0.93446779 -0.25373486 -0.51730365  0.72331643
  -0.75625718]
 [-0.6501748  -0.92748415 -0.95409006 -0.07157528  0.80637723 -0.32177576
  -1.4516511 ]
 [-1.081038   -0.67226124 -1.19455576  0.44537872 -0.69019234 -0.61539739
   0.15328468]
 [ 0.43032476 -0.11295394  0.83491379 -0.67906654  0.20325914 -0.0155068
   0.52107805]]
[[ 0.          0.          0.          2.34854746  0.          0.
   3.36598897]
 [ 0.7114948   0.          0.          0.          0.          0.72331643
   0.        ]
 [ 0.          0.          0.         -0.07157528  0.80637723  0.
   0.        ]
 [ 0.          0.          0.          0.44537872  0.          0.
   0.15328468]
 [ 0.          0.          0.83491379  0.          0.          0.
   0.52107805]]
SparseTensorValue(indices=array([[0, 3],
       [0, 6],
       [1, 0],
       [1, 5],
       [2, 3],
       [2, 4],
       [3, 3],
       [3, 6],
       [4, 2],
       [4, 6]]), values=array([ 2.34854746,  3.36598897,  0.7114948 ,  0.72331643, -0.07157528,
        0.80637723,  0.44537872,  0.15328468,  0.83491379,  0.52107805], dtype=float32), dense_shape=array([5, 7]))
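
The printed indices can be cross-checked outside TensorFlow. The sketch below is not the method used above: it is a NumPy alternative that pairs each row number, repeated k times, with the flattened column indices instead of using modular arithmetic. The sample data is rows 0 and 2 of the matrix printed above, rounded to two decimals, and np.argsort is assumed as a stand-in for tf.nn.top_k.

```python
import numpy as np

# Rows 0 and 2 of the printed matrix above, rounded; used only as sample data.
dense = np.array([[1.44, -1.02, -0.28, 2.35, -2.28, -0.62, 3.37],
                  [-0.65, -0.93, -0.95, -0.07, 0.81, -0.32, -1.45]])
k = 2

# Column indices of the k largest entries per row, largest first.
top_cols = np.argsort(-dense, axis=1)[:, :k]

# Each row index repeated k times, paired with the flattened column indices.
rows = np.repeat(np.arange(dense.shape[0]), k)
cols = top_cols.reshape(-1)
indices = np.stack([rows, cols], axis=1)
values = dense[rows, cols]

print(indices.tolist())  # [[0, 6], [0, 3], [1, 4], [1, 3]]
print(values.tolist())   # [3.37, 2.35, 0.81, -0.07]
```

The selected entries match the printed SparseTensorValue, but within each row they are ordered by value rather than by column, which is the ordering that tf.sparse_reorder corrects in the example above.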