优化光线跟踪的numpy数组操作

时间:2012-10-02 15:13:49

标签: python numpy

我有一些代码可以在地图中进行基本光线追踪,以确定光线是否撞到了墙壁。

[编辑]:y_coords和x_coords的大小通常为18x1000(对应点)。 self.map是800x800

def ray_trace(self, x, y, x_coords, y_coords):
    ray_distances = []
    resolution = self.parameters['resolution']
    for i in range(x_coords.shape[0]):
        distance = 0

        # filter x and y coords to stay within map regions
        ray_range = np.bitwise_and(x_coords[i]<799,y_coords[i]<799)

        # determine ending index where the ray stops
        len_ray = ray_range[ray_range==True].shape[0]

        # zip up the x and y coords
        ray_coords = np.c_[x_coords[i,0:len_ray], y_coords[i,0:len_ray]]

        # look up all the coordinates in the map and find where the map is
        # less than or equal to zero (this is a wall)
        ray_values, = np.where(self.map[tuple(ray_coords.T)] <= 0)

        # some special exceptions
        if not ray_values.shape[0]:
            if not len(ray_coords):
                end_of_ray = np.array([x/resolution, y/resolution])
            else:
                end_of_ray = ray_coords[len(ray_values)]
        else:
            # get the end of the ray
            end_of_ray = ray_coords[ray_values.item(0)]

        # find the distance from the originating point
        distance = math.sqrt((end_of_ray.item(0) - x/resolution)**2 + 
                             (end_of_ray.item(1) - y/resolution)**2)

        ray_distances.append(distance)
    return ray_distances

我在np.c_和np.where行中遇到了问题 - 我用kernprof.py描述了它们和那些行,并且它们花费了很长时间(尤其是np.c_,它占用了50%的时间)。有没有人对如何优化这个有任何想法?

1 个答案:

答案 0 :(得分:2)

你真的不需要那么多玩指数。高级索引意味着您可以使用两个大小相等的坐标数组进行索引,而无需先将它们组合成坐标。

coord_mask = (x_coords < 799) & (y_coords < 799)
for i in xrange(len(coord_mask)):
    distance = 0
    row_mask = coord_mask[i]
    row_x = x_coords[i, row_mask]
    row_y = y_coords[i, row_mask]
    mapvals = self.map[row_x, row_y] # advanced indexing
    ray_values, = (mapvals <= 0).nonzero()
    ...