获取pandas中最常见的虚拟变量的子集

时间:2013-08-02 12:06:39

标签: python pandas

我正在尝试执行一些线性回归分析,我有一些明确的功能,我使用超级棒的get_dummies转换为虚拟变量。

我面临的问题是,当我添加类别的所有元素时,数据框太大了。

有没有办法(使用get_dummies或更复杂的方法)来创建最常用术语的虚拟变量而不是所有术语?

3 个答案:

答案 0 :(得分:6)

使用value_counts()进行频率计数,然后为要保留的行创建掩码:

import pandas as pd
values = pd.Series(["a","b","a","b","c","d","e","a"])
counts = pd.value_counts(values)
mask = values.isin(counts[counts > 1].index)
print pd.get_dummies(values[mask])

输出:

   a  b
0  1  0
1  0  1
2  1  0
3  0  1
7  1  0

如果您想要所有数据:

values[~mask] = "-"
print pd.get_dummies(values)

输出:

   -  a  b
0  0  1  0
1  0  0  1
2  0  1  0
3  0  0  1
4  1  0  0
5  1  0  0
6  1  0  0
7  0  1  0

答案 1 :(得分:3)

我使用@HYRY给出的答案来编写一个函数,该函数将具有一个参数(阈值),可用于分隔流行值和不受欢迎的值(组合在'其他'列中)。

import pandas as pd
import numpy as np

# func that returns a dummified DataFrame of significant dummies in a given column
def dum_sign(dummy_col, threshold=0.1):

    # removes the bind
    dummy_col = dummy_col.copy()

    # what is the ratio of a dummy in whole column
    count = pd.value_counts(dummy_col) / len(dummy_col)

    # cond whether the ratios is higher than the threshold
    mask = dummy_col.isin(count[count > threshold].index)

    # replace the ones which ratio is lower than the threshold by a special name
    dummy_col[~mask] = "others"

    return pd.get_dummies(dummy_col, prefix=dummy_col.name)
#

让我们创建一些数据:

df = ['a', 'a', np.nan, np.nan, 'a', np.nan, 'a', 'b', 'b', 'b', 'b', 'b', 
             'c', 'c', 'd', 'e', 'g', 'g', 'g', 'g']

data = pd.Series(df, name='dums')

使用示例:

 In: dum_sign(data)
Out:
    dums_a  dums_b  dums_g  dums_others
0        1       0       0            0
1        1       0       0            0
2        0       0       0            1
3        0       0       0            1
4        1       0       0            0
5        0       0       0            1
6        1       0       0            0
7        0       1       0            0
8        0       1       0            0
9        0       1       0            0
10       0       1       0            0
11       0       1       0            0
12       0       0       0            1
13       0       0       0            1
14       0       0       0            1
15       0       0       0            1
16       0       0       1            0
17       0       0       1            0
18       0       0       1            0
19       0       0       1            0

 In: dum_sign(data, threshold=0.2)
Out: 
    dums_b  dums_others
0        0            1
1        0            1
2        0            1
3        0            1
4        0            1
5        0            1
6        0            1
7        1            0
8        1            0
9        1            0
10       1            0
11       1            0
12       0            1
13       0            1
14       0            1
15       0            1
16       0            1
17       0            1
18       0            1
19       0            1

 In: dum_sign(data, threshold=0)
Out: 
    dums_a  dums_b  dums_c  dums_d  dums_e  dums_g  dums_others
0        1       0       0       0       0       0            0
1        1       0       0       0       0       0            0
2        0       0       0       0       0       0            1
3        0       0       0       0       0       0            1
4        1       0       0       0       0       0            0
5        0       0       0       0       0       0            1
6        1       0       0       0       0       0            0
7        0       1       0       0       0       0            0
8        0       1       0       0       0       0            0
9        0       1       0       0       0       0            0
10       0       1       0       0       0       0            0
11       0       1       0       0       0       0            0
12       0       0       1       0       0       0            0
13       0       0       1       0       0       0            0
14       0       0       0       1       0       0            0
15       0       0       0       0       1       0            0
16       0       0       0       0       0       1            0
17       0       0       0       0       0       1            0
18       0       0       0       0       0       1            0
19       0       0       0       0       0       1            0

有关如何处理nans的任何建议?我认为nans不应该被视为'其他人。

UPD:我在一个相当大的数据集(5 mil obs)上测试了它,在我想要实现的列中有183个不同的字符串。我的笔记本电脑最多需要10秒钟。

答案 2 :(得分:1)

您可以先使用value_counts查看最常见的内容:

In [11]: s = pd.Series(list('aabccc'))

In [12]: s
Out[12]: 
0    a
1    a
2    b
3    c
4    c
5    c
dtype: object

In [13]: s.value_counts()
Out[13]: 
c    3
a    2
b    1
dtype: int64

最不频繁的值(例如除了前两个之外的所有值):

In [14]: s.value_counts().index[2:]
Out[14]: Index([u'b'], dtype=object)

您可以使用NaN简单地replace出现所有这些:

In [15]: s1 = s.replace(s.value_counts().index[2:], np.nan)

In [16]: s1
Out[16]: 
0      a
1      a
2    NaN
3      c
4      c
5      c
dtype: object

并执行get_dummies(我认为应该忽略NaN,但是有一个错误,因此notnull hack):

In [16]: pd.get_dummies(s1[s1.notnull()])
Out[16]: 
   a  c
0  1  0
1  1  0
3  0  1
4  0  1
5  0  1

如果您想要包含这些结果,可以使用其他占位符(例如'_')。