我有一个函数,它接收一个(h, w)
矩阵并生成一个3阶张量,该张量对于每一行都包含该行大小c
的所有可能移位。这是功能和用法示例
def slow_matrix_roll(v, c):
h, w = v.shape
v = np.pad(v, ((0, 0), (0, w - 1)), mode="constant")
res = np.zeros((h, w, c))
for j in range(w):
res[:, j, :] = v[:, j:j+c]
return res
inp = np.arange(1,10).reshape(3, 3)
res = slow_matrix_roll(inp, 2)
print(res.shape)
print(res)
输入:
[[1 2 3]
[4 5 6]
[7 8 9]]
输出:
(3, 3, 2)
[[[1. 2.]
[2. 3.]
[3. 0.]]
[[4. 5.]
[5. 6.]
[6. 0.]]
[[7. 8.]
[8. 9.]
[9. 0.]]]
例如,输入[1, 2, 3]
与c = 2
的第一行将产生一个矩阵
1 2
2 3
3 0
,并且每一行都会发生这种情况,从而导致张量。
我的问题是,如何使其更快?我想理想情况下,我想摆脱for循环,但是欢迎任何更快的解决方案。
答案 0 :(得分:1)
您可以使用stride_tricks
。
def fast_roll(v, c):
*h, w = v.shape
V = np.zeros((*h, w+c-1), v.dtype)
V[..., :w] = v
return np.lib.stride_tricks.as_strided(V, (*h, w, c), (*V.strides, V.strides[-1]))
请注意,这将创建非连续视图。如果需要,请制作连续副本。
示例:
>>> fast_roll(np.arange(9).reshape(3, 3), 2)
array([[[0, 1],
[1, 2],
[2, 0]],
[[3, 4],
[4, 5],
[5, 0]],
[[6, 7],
[7, 8],
[8, 0]]])