Python Pandas:查找包含numpy数组

时间:2016-12-12 20:18:49

标签: python pandas numpy

我有一个Pandas DataFrame,如下所示:

      values                                      max_val_idx
0    np.array([-0.649626, -0.662434, -0.611351])            2
1    np.array([-0.994942, -0.990448, -1.01574])             1
2    np.array([-1.012, -1.01034, -1.02732])                 0

df['values']包含固定长度为3个元素的numpy数组 df['max_val_idx]包含相应数组

的最大值的索引

由于已经给出了每个数组的最大元素的索引,因此提取每个条目的最大值的最有效方法是什么?
我知道数据存储有点傻,但我自己并没有创建它。而且由于我需要处理大量数据(+ - 50GB,数百个以类似方式存储的数据库),我想知道什么是最有效的方法。

到目前为止,我尝试遍历df['max_val_idx]的每个元素,并将其用作df['values']中找到的每个数组的索引:

max_val = []         
for idx, values in enumerate(df['values']):
     max_val.append(values[int(df['max_val_idx'].iloc[idx])])

还有更快的替代方案吗?

2 个答案:

答案 0 :(得分:4)

我会忘记'max_val_idx'列。我不认为它节省了时间,实际上更多的是语法上的痛苦。样本数据:

df = pd.DataFrame({ 'x': range(3) }).applymap( lambda x: np.random.randn(3) )

                                                   x
0  [-1.17106202376, -1.61211460669, 0.0198122724315]
1    [0.806819945736, 1.49139051675, -0.21434675401]
2  [-0.427272615966, 0.0939459129359, 0.496474566...

您可以像这样提取最大值:

df.applymap( lambda x: x.max() )

          x  
0  0.019812
1  1.491391
2  0.496475

但总的来说,如果每个细胞有一个数字,生活会更容易。如果每个单元格都有一个长度为3的数组,则可以重新排列如下:

for i, v in enumerate(list('abc')): df[v] = df.x.map( lambda x: x[i] )
df = df[list('abc')]

          a         b         c
0 -1.171062 -1.612115  0.019812
1  0.806820  1.491391 -0.214347
2 -0.427273  0.093946  0.496475

然后做一个标准的熊猫操作:

df.apply( max, axis=1 )

          x  
0  0.019812
1  1.491391
2  0.496475

不可否认,这并不比上面容易,但总体而言,这种形式的数据更容易使用。

答案 1 :(得分:2)

我不知道如何比较它的速度,因为我构建了所有行的2D矩阵,但这是一个可能的解决方案:

>>> np.choose(df['max_val_idx'], np.array(df['values'].tolist()).T)
0   -0.611351
1   -0.990448
2   -1.012000