NumPy中的单位矩阵堆叠

时间:2018-11-06 21:49:02

标签: python numpy

我需要在NumPy中创建一个2n x n矩阵,该矩阵由n x n单位矩阵和彼此堆叠的负n x n单位矩阵组成。

这是我最初的解决方案,效果很好。

def id_stack(n): 
    id_ = np.identity(n) 
    return np.vstack((id_, -id_))

id_stack(3)
# array([[ 1.,  0.,  0.],
#        [ 0.,  1.,  0.],
#        [ 0.,  0.,  1.],
#        [-1., -0., -0.],
#        [-0., -1., -0.],
#        [-0., -0., -1.]])

然后我想我可以设置对角线,这样可以更快,也可以。

def id_stack2(n): 
    full = np.zeros((2*n, n)) 
    rng = np.arange(n) 
    full[rng, rng] = 1 
    full[rng + n, rng] = -1     
    return full

我想知道是否有更快的方法来完成此操作,也许使用某种跨步技巧?

1 个答案:

答案 0 :(得分:1)

从您自己的示例中可能已经注意到,分配一个大缓冲区并在其中设置元素通常比分配两个较小的缓冲区和一个大缓冲区将它们复制到其中要快。

关于numpy的整洁之处在于,您可以在不分配新数组的情况下将视图移到同一缓冲区。例如:

 output = np.zeros((2 * n, n))

在这种情况下有用的视图是

flat = output.ravel()

您可以将第n + 1个元素从第一个开始设置为1,在展平视图中总共n个元素,对于-1则类似。这仅需要在经过筛选的视图上执行简单的索引操作:

output[:n * n:n + 1] = 1
output[n * n::n + 1] = -1

这避免了创建完整范围的数组,并避免了触发高级索引语义,这也需要占用更多的内存。