具有所有可能移动的滚动矩阵

时间:2019-02-24 12:10:01

标签: python numpy vectorization

我有一个函数,它接收一个(h, w)矩阵并生成一个3阶张量,该张量对于每一行都包含该行大小c的所有可能移位。这是功能和用法示例

def slow_matrix_roll(v, c):
    h, w = v.shape

    v = np.pad(v, ((0, 0), (0, w - 1)), mode="constant")
    res = np.zeros((h, w, c))

    for j in range(w):
        res[:, j, :] = v[:, j:j+c]

    return res 

inp = np.arange(1,10).reshape(3, 3)
res = slow_matrix_roll(inp, 2)

print(res.shape)
print(res)

输入:

[[1 2 3]
 [4 5 6]
 [7 8 9]]

输出:

(3, 3, 2)
[[[1. 2.]
  [2. 3.]
  [3. 0.]]

 [[4. 5.]
  [5. 6.]
  [6. 0.]]

 [[7. 8.]
  [8. 9.]
  [9. 0.]]]

例如,输入[1, 2, 3]c = 2的第一行将产生一个矩阵

1 2
2 3
3 0

,并且每一行都会发生这种情况,从而导致张量。

我的问题是,如何使其更快?我想理想情况下,我想摆脱for循环,但是欢迎任何更快的解决方案。

1 个答案:

答案 0 :(得分:1)

您可以使用stride_tricks

def fast_roll(v, c):
    *h, w = v.shape
    V = np.zeros((*h, w+c-1), v.dtype)
    V[..., :w] = v
    return np.lib.stride_tricks.as_strided(V, (*h, w, c), (*V.strides, V.strides[-1]))

请注意,这将创建非连续视图。如果需要,请制作连续副本。

示例:

>>> fast_roll(np.arange(9).reshape(3, 3), 2)
array([[[0, 1],
        [1, 2],
        [2, 0]],

       [[3, 4],
        [4, 5],
        [5, 0]],

       [[6, 7],
        [7, 8],
        [8, 0]]])