使用步幅

时间:2017-02-01 22:39:59

标签: python numpy

考虑numpy数组a

a = np.arange(18).reshape(2, 3, 3)
print(a)

[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]]

我想沿着axis=0填写每个切片的对角线 我使用this answer

完成了这项工作
a[:, np.arange(3), np.arange(3)] = -1

print(a)

[[[-1  1  2]
  [ 3 -1  5]
  [ 6  7 -1]]

 [[-1 10 11]
  [12 -1 14]
  [15 16 -1]]]

我怎么能用步幅做到这一点?

1 个答案:

答案 0 :(得分:1)

这是strides -

的方法
def fill_diag_strided(a, fillval):
    m,n,r = a.strides
    a_diag = np.lib.stride_tricks.as_strided(a, shape=a.shape[:2],strides=(m,r+n))
    a_diag[:] = fillval

示例运行 -

In [369]: a
Out[369]: 
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])

In [370]: fill_diag_strided(a, fillval=-1)

In [371]: a
Out[371]: 
array([[[-1,  1,  2],
        [ 3, -1,  5],
        [ 6,  7, -1]],

       [[-1, 10, 11],
        [12, -1, 14],
        [15, 16, -1]]])

针对基于索引的方法进行运行时测试 -

In [35]: a = np.random.rand(100000,3,3)

In [36]: n = a.shape[1]

In [37]: %timeit a[:, np.arange(n), np.arange(n)] = -1
100 loops, best of 3: 4.45 ms per loop

In [38]: %timeit fill_diag_strided(a, fillval=-1)
1000 loops, best of 3: 1.76 ms per loop