在Pandas DataFrame中查找第一列匹配条件的矢量化方法

时间:2019-01-03 19:57:39

标签: python pandas

假设我有以下熊猫DataFrame:

          A         B         C
0  0.548814  0.791725  0.978618
1  0.715189  0.528895  0.799159
2  0.602763  0.568045  0.461479
3  0.544883  0.925597  0.780529
4  0.423655  0.071036  0.118274
5  0.645894  0.087129  0.639921
6  0.437587  0.020218  0.143353
7  0.891773  0.832620  0.944669
8  0.963663  0.778157  0.521848
9  0.383442  0.870012  0.414662

可以使用以下代码创建

import pandas as pd
import numpy as np

size = 10
np.random.seed(0)
keys = ["A", "B", "C"]
df = pd.DataFrame({k: np.random.random(size) for k in keys})

如何找到符合指定条件的第一个

在这种情况下,假设我的标准是我想要第一列中的值小于某个p,例如0.5。如果没有列符合此条件,我想返回"No Match"

使用apply,可以这样做如下:

p = 0.5
first = df.apply(
    lambda row: next((x for i, x in enumerate(df.columns) if row[x]<p), "No Match"), 
    axis=1
)
print(first)
#0    No Match
#1    No Match
#2           C
#3    No Match
#4           A
#5           B
#6           A
#7    No Match
#8    No Match
#9           A
#dtype: object

是否有一种更有效的(矢量化)方法?我在想应该使用argmax()的某种方式,但是我还没有开始使用它。

此外,我使用的是熊猫0.19.2,我不确定是否可以升级。

print(pd.__version__)
#u'0.19.2'

2 个答案:

答案 0 :(得分:3)

您可以使用NumPy argmax,但需要覆盖给定行中从不满足条件的实例:

mask = df.lt(0.5)
df['first'] = np.where(mask.any(1), df.columns[mask.values.argmax(1)], 'No Match')

您也可以使用熊猫idxmax

df['first'] = np.where(mask.any(1), mask.idxmax(1), 'No Match')

print(df)

          A         B         C     first
0  0.548814  0.791725  0.978618  No Match
1  0.715189  0.528895  0.799159  No Match
2  0.602763  0.568045  0.461479         C
3  0.544883  0.925597  0.780529  No Match
4  0.423655  0.071036  0.118274         A
5  0.645894  0.087129  0.639921         B
6  0.437587  0.020218  0.143353         A
7  0.891773  0.832620  0.944669  No Match
8  0.963663  0.778157  0.521848  No Match
9  0.383442  0.870012  0.414662         A

答案 1 :(得分:3)

IIUC与dot

df.lt(0.5).dot(df.columns).str[0].fillna('notmatch')
Out[167]: 
0    notmatch
1    notmatch
2           C
3    notmatch
4           A
5           B
6           A
7    notmatch
8    notmatch
9           A
dtype: object