二进制矩阵,其中1位于沿行的最大值处

时间:2017-05-26 21:30:52

标签: tensorflow

假设我有一个张量

[[0.3, 0.7],
[0.9,  0.1]]

如何在轴的最大位置创建一个1.0的张量,因此结果应该是轴= 1

[[0., 1.],
[1.,  0.]]

在我的情况下,第一个维度是批量大小,所以它是'?'

2 个答案:

答案 0 :(得分:1)

所提出的两个答案在内存/计算方面都是低效的。

你可以在线性时间(no-matmul)中计算它,而无需在一行中分配不必要的内存:

tf.cast(tf.equal(a, tf.reshape(tf.reduce_max(a, axis=1), (-1, 1))), tf.int16)

完整的例子在这里:

import tensorflow as tf
a = tf.constant([
    [1, 9, 1, 6],
    [6, 5, 0, 6],
    [4, 0, 7, 6],
    [1, 5, 9, 1]
])
b = tf.cast(tf.equal(a, tf.reshape(tf.reduce_max(a, axis=1), (-1, 1))), tf.int16)

with tf.Session() as sess:
    print sess.run(b)

哪个会给你

[[0 1 0 0]
 [1 0 0 1]
 [0 0 1 0]
 [0 0 1 0]]

如您所见,tf.equal使用broadcasting来减少内存数量。

答案 1 :(得分:0)

您可以使用tf.arg_max()索引矩阵的最大值,然后使用tf.sparse_to_dense()派生出您想要的新矩阵。

这是一个例子。

dispose()

' B'是你想要的。您可以通过设置" output_shape',' sparse_values'来获得其他形式的矩阵。和' default_value'。