考虑numpy数组a
a = np.arange(18).reshape(2, 3, 3)
print(a)
[[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[12 13 14]
[15 16 17]]]
我想沿着axis=0
填写每个切片的对角线
我使用this answer
a[:, np.arange(3), np.arange(3)] = -1
print(a)
[[[-1 1 2]
[ 3 -1 5]
[ 6 7 -1]]
[[-1 10 11]
[12 -1 14]
[15 16 -1]]]
我怎么能用步幅做到这一点?
答案 0 :(得分:1)
这是strides
-
def fill_diag_strided(a, fillval):
m,n,r = a.strides
a_diag = np.lib.stride_tricks.as_strided(a, shape=a.shape[:2],strides=(m,r+n))
a_diag[:] = fillval
示例运行 -
In [369]: a
Out[369]:
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])
In [370]: fill_diag_strided(a, fillval=-1)
In [371]: a
Out[371]:
array([[[-1, 1, 2],
[ 3, -1, 5],
[ 6, 7, -1]],
[[-1, 10, 11],
[12, -1, 14],
[15, 16, -1]]])
针对基于索引的方法进行运行时测试 -
In [35]: a = np.random.rand(100000,3,3)
In [36]: n = a.shape[1]
In [37]: %timeit a[:, np.arange(n), np.arange(n)] = -1
100 loops, best of 3: 4.45 ms per loop
In [38]: %timeit fill_diag_strided(a, fillval=-1)
1000 loops, best of 3: 1.76 ms per loop