在方形矩阵(n,n)的帖子Get the diagonal of a matrix in TensorFlow中,一个sugestion使用函数tf.diag_part(张量)。但如果张量是尺寸(k,n,n)?存在任何方式吗?必要的输出是张量的k方矩阵(n,n)的k对角线,也就是说,我需要一个维数(k,n)的输出。一些建议?
答案 0 :(得分:0)
您可以使用documentation中的tf.map_fn
:
在尺寸0上从elems解包的张量列表上进行映射。
所以你只需要映射tf.diag_part
:
a = tf.placeholder(shape=[100, 10, 10], dtype=tf.float32)
diags = tf.map_fn(tf.diag_part, a, parallel_iterations=100)
diags
的形状为(100, 10)
注意:parallel_iterations
理想情况下应该等于k
,以获得最佳效果。