我想创建一个M * N张量,其中所有元素都是零,除了每行一个随机元素,它应该是一个,但我不知道如何。
答案 0 :(得分:0)
这是一种方法:
import tensorflow as tf
m = 4
n = 6
dt = tf.float32
random_idx = tf.random_uniform((m, 1), maxval=n, dtype=tf.int32)
result = tf.cast(tf.equal(tf.range(n)[tf.newaxis], random_idx), dtype=dt)
with tf.Session() as sess:
print(sess.run(result))
输出:
[[ 0. 0. 0. 0. 0. 1.]
[ 0. 0. 1. 0. 0. 0.]
[ 0. 1. 0. 0. 0. 0.]
[ 0. 1. 0. 0. 0. 0.]]