批处理的numpy数组索引

时间:2016-12-21 09:50:59

标签: python numpy

我有源和目标numpy数组,比如说

dest = np.arange(1000)
src = np.random(500)

我希望在src中复制到dest数组,覆盖前50个数字,然后是从索引100到150开始的所有数字,然后是200到250,依此类推,直到900到950.

要计算索引,我使用以下内容:

index = np.reshape([100*i+np.arange(50) for i in range(10)],-1)

然后我只使用

dest[index]=src

复制剩余的元素(与另一个src,比如src2),我只是调整索引

index2 = np.reshape([50+100*i+np.arange(50) for i in range(10)],-1)
dest[index2]=src2

我非常确定这是一种更优雅/更有效的方法,无需显式构建索引。

有没有更好的方法来执行复制?

2 个答案:

答案 0 :(得分:1)

dest.reshape(-1,50)[::2] = src.reshape(-1,50)

对于src2

dest.reshape(-1,50)[1::2] = src2.reshape(-1,50)

答案 1 :(得分:1)

我想最简单的方法是将两者重塑为矩阵

dest = dest.reshape((10, 100))
src = src.reshape((10, 50))

并使用矩阵索引

dest[:,:50] = src

我们的想法是创建一个100 x 10的dest矩阵,并用src替换一半转换为50 x 10矩阵。

修改

@Haminaa的答案稍快一些。

%timeit dest.reshape((10, -1))[:,:50] = src.reshape((10, -1))
  

最慢的跑步比最快跑的时间长11.54倍。这可以   表示正在缓存中间结果。 1000000循环,最好   3:每循环1.55μs

%timeit dest.reshape(-1,50)[::2] = src.reshape(-1,50)
  

最慢的跑步比最快跑的时间长12.97倍。这可以   表示正在缓存中间结果。 1000000循环,最好   3:每循环1.38μs