我有一个数组
sizeof
如何用y替换X的对角线。我们可以编写一个循环,但是有任何更快的方法吗?
答案 0 :(得分:3)
这应该非常快(特别是对于较大的数组,对于您的示例来说,速度要慢大约两倍):
arr = np.zeros((4,4))
replace = [1,2,3,4]
l = len(arr)
arr.shape = -1
arr[::l+1] = replace
arr.shape = l,l
在更大的数组上测试:
n = 100
arr = np.zeros((n,n))
replace = np.ones(n)
def loop():
for i in range(len(arr)):
arr[i,i] = replace[i]
def other():
l = len(arr)
arr.shape = -1
arr[::l+1] = replace
arr.shape = l,l
%timeit(loop())
%timeit(other())
14.7 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.55 µs ± 24.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
答案 1 :(得分:3)
一种快速可靠的方法是np.einsum
:
>>> diag_view = np.einsum('ii->i', X)
这将创建对角线视图:
>>> diag_view
array([0.7513, 0.8909, 0.1493, 0.2543])
此视图可写:
>>> diag_view[None] = y
>>> X
array([[1. , 0.6991, 0.5472, 0.2575],
[0.2551, 2. , 0.1386, 0.8407],
[0.506 , 0.9593, 3. , 0.2543],
[0.506 , 0.9593, 0.1493, 4. ]])
这适用于连续和非连续数组,并且速度非常快:
contiguous:
loop 21.146424998732982
diag_indices 2.595232878000388
einsum 1.0271988900003635
flatten 1.5372659160002513
non contiguous:
loop 20.133818001340842
diag_indices 2.618005960001028
einsum 1.0305795049989683
Traceback (most recent call last): <- flatten does not work here
...
它如何工作? einsum
进行了@Julien技巧的高级版本:它增加了arr
的步幅:
>>> arr.strides
(3200, 16)
>>> np.einsum('ii->i', arr).strides
(3216,)
一个人可以说服自己,只要arr跨步组织,这将一直有效,这是numpy数组的情况。
虽然einsum
的这种用法非常简洁,但是如果不知道,几乎也找不到。所以传播这个词!
重新创建计时和崩溃的代码:
import numpy as np
n = 100
arr = np.zeros((n, n))
replace = np.ones(n)
def loop():
for i in range(len(arr)):
arr[i,i] = replace[i]
def other():
l = len(arr)
arr.shape = -1
arr[::l+1] = replace
arr.shape = l,l
def di():
arr[np.diag_indices(arr.shape[0])] = replace
def es():
np.einsum('ii->i', arr)[...] = replace
from timeit import timeit
print('\ncontiguous:')
print('loop ', timeit(loop, number=1000)*1000)
print('diag_indices ', timeit(di))
print('einsum ', timeit(es))
print('flatten ', timeit(other))
arr = np.zeros((2*n, 2*n))[::2, ::2]
print('\nnon contiguous:')
print('loop ', timeit(loop, number=1000)*1000)
print('diag_indices ', timeit(di))
print('einsum ', timeit(es))
print('flatten ', timeit(other))
答案 2 :(得分:2)
将diag_indices
用于矢量化解决方案:
X[np.diag_indices(X.shape[0])] = y
array([[1. , 0.6991, 0.5472, 0.2575],
[0.2551, 2. , 0.1386, 0.8407],
[0.506 , 0.9593, 3. , 0.2543],
[0.506 , 0.9593, 0.1493, 4. ]])