假设我有一个张量
[[0.3, 0.7],
[0.9, 0.1]]
如何在轴的最大位置创建一个1.0的张量,因此结果应该是轴= 1
[[0., 1.],
[1., 0.]]
在我的情况下,第一个维度是批量大小,所以它是'?'
答案 0 :(得分:1)
所提出的两个答案在内存/计算方面都是低效的。
你可以在线性时间(no-matmul)中计算它,而无需在一行中分配不必要的内存:
tf.cast(tf.equal(a, tf.reshape(tf.reduce_max(a, axis=1), (-1, 1))), tf.int16)
完整的例子在这里:
import tensorflow as tf
a = tf.constant([
[1, 9, 1, 6],
[6, 5, 0, 6],
[4, 0, 7, 6],
[1, 5, 9, 1]
])
b = tf.cast(tf.equal(a, tf.reshape(tf.reduce_max(a, axis=1), (-1, 1))), tf.int16)
with tf.Session() as sess:
print sess.run(b)
哪个会给你
[[0 1 0 0]
[1 0 0 1]
[0 0 1 0]
[0 0 1 0]]
如您所见,tf.equal
使用broadcasting来减少内存数量。
答案 1 :(得分:0)
您可以使用tf.arg_max()索引矩阵的最大值,然后使用tf.sparse_to_dense()派生出您想要的新矩阵。
这是一个例子。
dispose()
' B'是你想要的。您可以通过设置" output_shape',' sparse_values'来获得其他形式的矩阵。和' default_value'。