指定轴的numpy`diagflat`

时间:2018-09-21 11:58:52

标签: python numpy linear-algebra

我想以ndarray的形式将diagflat的轴扩展为对角矩阵

例如

In [ ]: import numpy as np

In [ ]: example = np.random.random((200, 5))
In [ ]: example.shape
Out[ ]: (200, 5)

我正在寻找的东西是这样的:

In [ ]: np.diagflat(example, axis=-1).shape
Out[ ]: (200, 5, 5)

diagflataxis参数。我的想法是例如简单地插入一个新轴并将其与单位矩阵相乘。

In [ ]: Id = np.eye(example.shape[-1])
In [ ]: (example[..., np.newaxis] @ Id).shape
ValueError: shapes (200,5,1) and (5,5) not aligned: 1 (dim 2) != 5 (dim 0)

然而,这引起了错误,显然广播不应用于矩阵乘法。是否有一个优雅的解决方案,还是我必须手动创建并填充数组?

1 个答案:

答案 0 :(得分:1)

只需:

example[..., np.newaxis] * Id