有没有办法在TensorFlow或Keras中提取(k,n,n)张量的k对角线?

时间:2018-04-14 03:29:48

标签: python tensorflow keras diagonal

在方形矩阵(n,n)的帖子Get the diagonal of a matrix in TensorFlow中,一个sugestion使用函数tf.diag_part(张量)。但如果张量是尺寸(k,n,n)?存在任何方式吗?必要的输出是张量的k方矩阵(n,n)的k对角线,也就是说,我需要一个维数(k,n)的输出。一些建议?

1 个答案:

答案 0 :(得分:0)

您可以使用documentation中的tf.map_fn

  

在尺寸0上从elems解包的张量列表上进行映射。

所以你只需要映射tf.diag_part

a = tf.placeholder(shape=[100, 10, 10], dtype=tf.float32)
diags = tf.map_fn(tf.diag_part, a, parallel_iterations=100)

diags的形状为(100, 10)

注意:parallel_iterations理想情况下应该等于k,以获得最佳效果。