如何使用numpy访问矩阵的相邻元素?

时间:2019-09-28 10:51:40

标签: python numpy

我编写了一个计算流体溶解度的代码,问题是该代码非常差,所以我一直在用numpy进行查看,我可以对其进行优化,但是我一直不知道如何执行以下操作使用numpy和roll函数的代码。基本上,我有一个索引不能大于1024的矩阵,为此,我使用%来计算它是什么索引。但这需要很长时间。

我尝试使用numpy,使用roll,旋转矩阵,然后不必计算模块。但是我不知道如何取用邻居的价值观。

def evolve(grid, dt, D=1.0):
  xmax, ymax = grid_shape
  new_grid = [[0.0,] * ymax for x in range(xmax)]
  for i in range(xmax):
    for j in range(ymax):
      grid_xx = grid[(i+1)%xmax][j] + grid[(i-1)%xmax][j] - 2.0 * grid[i][j]
      grid_yy = grid[i][(j+1)%ymax] + grid[i][(j-1)%ymax] - 2.0 * grid[i][j]
      new_grid[i][j] = grid[i][j] + D * (grid_xx + grid_yy) * dt
  return new_grid 

1 个答案:

答案 0 :(得分:0)

您必须使用evolve从(几乎)零重写numpy函数。

这里是准则:

  • 首先,grid必须是2D numpy数组,而不是列表列表。
  • 您的老师建议使用roll功能:查看其docs并尝试了解其工作原理。 roll将通过在一个轴上移动(或滚动)矩阵来解决在矩阵中查找邻居条目的问题。然后,您可以在四个方向上创建grid的偏移版本并使用它们,而不用搜索邻居。
  • 一旦移动了grid,您将看到不需要for循环来计算new_grid的每个像元:可以使用矢量化计算,速度更快。 li>

所以代码将如下所示:

def evolve(grid, dt, D=1.0):
    if not isinstance(grid, np.ndarray): #ensuring that is a numpy array.
        grid = np.array(grid)
    u_grid = np.roll(grid, 1, axis=0)
    d_grid = np.roll(grid, -1, axis=0)
    r_grid = np.roll(grid, 1, axis=1)
    l_grid = np.roll(grid, -1, axis=1)
    new_grid = grid + D * (u_grid + d_grid + r_grid + l_grid - 4.0*grid) * dt
    return new_grid

对于1024 x 1024的矩阵,每个小数evolve(在我的机器上)大约需要0.15秒才能返回new_grid。您的evolve和for循环大约需要3.85秒。