大熊猫:仅保留前n个值,并将其他值设置为0

时间:2018-11-06 09:49:05

标签: python pandas

在熊猫数据框中,对于每一行,我只想保留前N个值,并将其他所有值都设置为0。 我可以遍历各行并做到这一点,但是我确信python / pandas可以在一行中优雅地做到这一点。

例如:对于N = 2

Input:
A   B   C   D
4   10  10  6
5   20  50  90
6   30  6   4
7   40  12  9

Output:
A   B   C   D
0   10  10  0
0   0   50  90
6   30  6   0
0   40  12  0

4 个答案:

答案 0 :(得分:5)

rank与参数axis=1method='min'ascending=False一起使用:

N = 2
df = df.mask(df.rank(axis=1, method='min', ascending=False) > N, 0)

或者将np.wherepd.DataFrame一起使用,它比mask方法要快:

df = pd.DataFrame(np.where(df.rank(axis=1,method='min',ascending=False)>N, 0, df),
                  columns=df.columns)

print(df)
   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0

说明:

步骤1: 首先,我们需要找出该行中2个最小的数字是什么,以及是否有重复项需要考虑。因此,使用axis=1跨行进行排名,method='min'ascending = False将注意重复的值:

print(df.rank(axis=1, method='min', ascending=False))
     A    B    C    D
0  4.0  1.0  1.0  3.0
1  4.0  3.0  2.0  1.0
2  2.0  1.0  2.0  4.0
3  4.0  1.0  2.0  3.0

步骤2:其次,我们需要根据条件过滤值大于(N)的位置,然后使用mask更改这些值:

print(df.rank(axis=1, method='min', ascending=False) > N)
       A      B      C      D
0   True  False  False   True
1   True   True  False  False
2  False  False  False   True
3   True  False  False   True

print(df.mask(df.rank(axis=1, method='min', ascending=False) > N, 0))
   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0

答案 1 :(得分:3)

使用:

N = 2
df = df.where(df.apply(lambda x: x.isin(x.nlargest(N)), axis=1), 0)
print (df)
   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0

或者:

import heapq
N = 2
df = df.where(df.apply(lambda x: x.isin(heapq.nlargest(N, x)), axis=1), 0)
print (df)
   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0

答案 2 :(得分:1)

使用nlargest获得N个最大数字:

df.mask(~df.apply(lambda x: x.isin(x.nlargest(2)), axis=1), 0)

Outpu:

    A   B   C   D
0   0   10  10  0
1   0   0   50  90
2   6   30  6   0
3   0   40  12  0

答案 3 :(得分:1)

您可以通过scipy.stats.rankdata使用np.apply_along_axis,并输入pd.DataFrame.where

if(isset($_POST['submit'])) {
    $username = $_POST['username'];
    $_SESSION['username'] = $username;
}

性能基准测试

pd.DataFrame.rank是以下解决方案中效率最高的; from scipy.stats import rankdata df[:] = df.where(np.apply_along_axis(rankdata, 1, df, method='max') > 2, 0) print(df) A B C D 0 0 10 10 0 1 0 0 50 90 2 6 30 6 0 3 0 40 12 0 + apply表现最差。

lambda