熊猫:具有两个数据帧功能的矢量化

时间:2018-08-24 03:47:02

标签: python pandas vectorization

我在熊猫中实现矢量化遇到麻烦。首先,我要说我是向量化的新手,因此很有可能我弄错了一些语法。

假设我有两个熊猫数据框。

数据框一描述了一些半径为R且具有唯一ID的圆的x,y坐标。

>>> data1 = {'ID': [1, 2], 'x': [1, 10], 'y': [1, 10], 'R': [4, 5]}
>>> df_1=pd.DataFrame(data=data1)
>>>
>>> df_1
   ID  x   y   R
   1   1   1   4
   2   10  10  5

数据框2描述了某些点的x,y坐标,也具有唯一的ID。

>>> data2 = {'ID': [3, 4, 5], 'x': [1, 3, 9], 'y': [2, 5, 9]}
>>> df_2=pd.DataFrame(data=data2)
>>>
>>> df_2
   ID  x  y
   3   1  2
   4   3  5
   5   9  9

现在,想象一下在2D平面上绘制圆和点。一些要点将位于圆圈内。请参见下图。

enter image description here

我要做的就是在df_2中创建一个名为“ host_circle”的新列,该列指示每个点所在的圆的ID。如果粒子不在圆中,则该值应为“ None”。

我想要的输出是

>>> df_2
   ID  x  y   host_circle
   3   1  2   1 
   4   3  5   None 
   5   9  9   2

首先,定义一个检查给定粒子(x2,y2)是否位于给定圆(x1,y1,R1,ID_1)内的函数。如果是,则返回圆的ID;否则,返回0。否则,返回None。

>>> def func(x1,y1,R1,ID_1,x2,y2):
...     dist = np.sqrt( (x1-x2)**2 + (y1-y2)**2 )
...     if dist < R:
...         return ID_1
...     else:
...        return None

接下来,是实际的向量化。我有点迷路了。我认为应该是

df_2['host']=func(df_1['x'],df_1['y'],df_1['R'],df_1['ID'],df_2['x'],df_2['y'])

但这只会引发错误。有人可以帮我吗?

最后一点:我正在使用的实际数据非常大;数千万行。速度至关重要,因此为什么我要尝试向量化。

2 个答案:

答案 0 :(得分:5)

Numba v1

您可能必须使用

安装numba
pip install numba

然后通过numba函数装饰器使用njit的jit编译器

from numba import njit

@njit
def distances(point, points):
  return ((points - point) ** 2).sum(1) ** .5

@njit
def find_my_circle(point, circles):
  points = circles[:, :2]
  radii = circles[:, 2]
  dist = distances(point, points)
  mask = dist < radii
  i = mask.argmax()
  return i if mask[i] else -1

@njit
def find_my_circles(points, circles):
  n = len(points)
  out = np.zeros(n, np.int64)
  for i in range(n):
    out[i] = find_my_circle(points[i], circles)
  return out

ids = np.append(df_1.ID.values, np.nan)

i = find_my_circles(points, df_1[['x', 'y', 'R']].values)
df_2['host_circle'] = ids[i]

df_2

   ID  x  y  host_circle
0   3  1  2          1.0
1   4  3  5          NaN
2   5  9  9          2.0

此操作逐行迭代...意味着每次尝试查找宿主圆时都指向一个点。现在,该部分仍被矢量化。 并且,循环应该非常快。这样做的好处是您不必占用大量内存。


Numba v2

这是一个循环的但找到主机时会短路

from numba import njit

@njit
def distance(a, b):
  return ((a - b) ** 2).sum() ** .5

@njit
def find_my_circles(points, circles):
  n = len(points)
  m = len(circles)

  out = -np.ones(n, np.int64)

  centers = circles[:, :2]
  radii = circles[:, 2]

  for i in range(n):
    for j in range(m):
      if distance(points[i], centers[j]) < radii[j]:
        out[i] = j
        break

  return out

ids = np.append(df_1.ID.values, np.nan)

i = find_my_circles(points, df_1[['x', 'y', 'R']].values)
df_2['host_circle'] = ids[i]

df_2

矢量化

但仍然有问题

c = ['x', 'y']
centers = df_1[c].values
points = df_2[c].values
radii = df_1['R'].values

i, j = np.where(((points[:, None] - centers) ** 2).sum(2) ** .5 < radii)

df_2.loc[df_2.index[i], 'host_circle'] = df_1['ID'].iloc[j].values

df_2

   ID  x  y  host_circle
0   3  1  2          1.0
1   4  3  5          NaN
2   5  9  9          2.0

说明

从圆心到任意点的距离是

((x1 - x0) ** 2 + (y1 - y0) ** 2) ** .5

如果我将一个数组扩展到第三维,就可以使用广播

points[:, None] - centers

array([[[ 0,  1],
        [-9, -8]],

       [[ 2,  4],
        [-7, -5]],

       [[ 8,  8],
        [-1, -1]]])

这是向量差的所有六个组合。现在计算距离。

((points[:, None] - centers) ** 2).sum(2) ** .5

array([[ 1.        , 12.04159458],
       [ 4.47213595,  8.60232527],
       [11.3137085 ,  1.41421356]])

那是距离的所有6种组合,我可以将其与半径进行比较以查看哪些在圆内

((points[:, None] - centers) ** 2).sum(2) ** .5 < radii

array([[ True, False],
       [False, False],
       [False,  True]])

好的,我想找到True值在哪里。这是np.where的完美用例。它会给我两个数组,第一个是行位置,第二个是这些True值所在的列位置。事实证明,行位置是points,列位置是圆圈。

i, j = np.where(((points[:, None] - centers) ** 2).sum(2) ** .5 < radii)

现在,我只需要以某种方式将df_2切成i,并以某种方式分配使用df_1j得到的值……但是我在上面展示了这一点。

答案 1 :(得分:0)

尝试一下。我已经对您的函数进行了一些修改以进行计算,并且假设有很多圆满足一个点,我将获得此列表。如果不是这种情况,可以修改它。如果粒子不在任何一个圆中,它也会是零成员列表

def func(df, x2,y2):
    val = df.apply(lambda row: np.sqrt((row['x']-x2)**2 + (row['y']-y2)**2) < row['R'], axis=1)
    return list(val.index[val==True])

df_2['host'] = df_2.apply(lambda row: func(df_1, row['x'],row['y']), axis=1)