正如标题所述,我试图从张量流中的矩阵中提取每行最高的n个元素,并将结果存储在稀疏的Tensor中。
我已经能够使用tf.nn.top_n提取索引和值,但索引不符合tf.SparseTensor要求的约定。
具体来说,tf.nn.top_n返回一个col索引矩阵,其形状与结果值矩阵(Rows xn)相同,而tf.SparseTensor想要一个(#non-zero x 2)矩阵,每个非1行-zero元素和包含row和col索引的列。
这些值可能是一个类似的问题,需要一个非零元素列表而不是一个值矩阵。
如何在这些索引表示法方案之间快速转换?
答案 0 :(得分:2)
这可以通过一些模运算来实现。这是一个适用于矩阵的示例,尽管可以循环更多的轴。
import tensorflow as tf
def slices_to_dims(slice_indices):
"""
Args:
slice_indices: An [N, k] Tensor mapping to column indices.
Returns:
An index Tensor with shape [N * k, 2], corresponding to indices suitable for
passing to SparseTensor.
"""
slice_indices = tf.cast(slice_indices, tf.int64)
num_rows = tf.shape(slice_indices, out_type=tf.int64)[0]
row_range = tf.range(num_rows)
item_numbers = slice_indices * num_rows + tf.expand_dims(row_range, axis=1)
item_numbers_flat = tf.reshape(item_numbers, [-1])
return tf.stack([item_numbers_flat % num_rows,
item_numbers_flat // num_rows], axis=1)
使用示例:
dense_shape = [5, 7]
dense_matrix = tf.random_normal(shape=dense_shape)
top_values, top_indices = tf.nn.top_k(dense_matrix, k=2)
sparse_indices = slices_to_dims(top_indices)
sparse_tensor = tf.sparse_reorder(tf.SparseTensor(
indices=sparse_indices,
values=tf.reshape(top_values, [-1]),
dense_shape=dense_shape))
densified_top = tf.sparse_tensor_to_dense(sparse_tensor)
with tf.Session() as session:
sparse_top, dense_original, dense_selected = session.run(
[sparse_tensor, dense_matrix, densified_top])
print(dense_original)
print(dense_selected)
print(sparse_top)
打印:
[[ 1.44056129 -1.01790774 -0.2795608 2.34854746 -2.27528405 -0.62035948
3.36598897]
[ 0.7114948 -0.42564821 -0.93446779 -0.25373486 -0.51730365 0.72331643
-0.75625718]
[-0.6501748 -0.92748415 -0.95409006 -0.07157528 0.80637723 -0.32177576
-1.4516511 ]
[-1.081038 -0.67226124 -1.19455576 0.44537872 -0.69019234 -0.61539739
0.15328468]
[ 0.43032476 -0.11295394 0.83491379 -0.67906654 0.20325914 -0.0155068
0.52107805]]
[[ 0. 0. 0. 2.34854746 0. 0.
3.36598897]
[ 0.7114948 0. 0. 0. 0. 0.72331643
0. ]
[ 0. 0. 0. -0.07157528 0.80637723 0. 0. ]
[ 0. 0. 0. 0.44537872 0. 0.
0.15328468]
[ 0. 0. 0.83491379 0. 0. 0.
0.52107805]]
SparseTensorValue(indices=array([[0, 3],
[0, 6],
[1, 0],
[1, 5],
[2, 3],
[2, 4],
[3, 3],
[3, 6],
[4, 2],
[4, 6]]), values=array([ 2.34854746, 3.36598897, 0.7114948 , 0.72331643, -0.07157528,
0.80637723, 0.44537872, 0.15328468, 0.83491379, 0.52107805], dtype=float32), dense_shape=array([5, 7]))