我有一个2D张量,各种数组定义为:
x = tf.constant([[0,1,2],[-1,0,1],[-1,-2,0]])
我希望将每个数组转换为对角矩阵:
diag_x =
[[[ 0, 0, 0],
[ 0, 1, 0],
[ 0, 0, 2]],
[[-1, 0, 0],
[ 0, 0, 0],
[ 0, 0, 1]],
[[-2, 0, 0],
[ 0, -1, 0],
[ 0, 0, 0]]]
但如果我使用 tf.diag(x)操作,则输出不是这个。
答案 0 :(得分:3)
您可以尝试:
tf.matrix_set_diag(tf.zeros((3,3,3), dtype=tf.int32), x)
答案 1 :(得分:2)
我终于找到了解决方案:
tf.matrix_diag(x)