Pandas groupby & linregress: how to extract the results

Time: 2016-05-07 04:17:49

Tags: python pandas

I am running a linear regression per group to produce summary statistics. I used scipy's linregress to compute the regression of two variables, km and price:

import pandas as pd
from scipy.stats import linregress    
df = pd.read_csv('test dataset faceted small.csv')
grouped = df.groupby(['year','make','engine','drive','transmission','badge'])
test = grouped.apply(lambda x: linregress(x['km'], x['price']))
print(test)
test.to_csv('grouped.csv', index=False)

Printing test gives me:

year  make    engine  drive  transmission  badge                
1994  subaru  1.6L    awd    auto          wrx                      (-0.0019029525668, 2217.67284738, -0.190381626...
1997  mazda   1.3L    2wd    manual        121 metro                (-0.00724142957301, 4213.71579612, -0.30608491...
1999  nissan  1.6L    2wd    auto          pulsar plus lx n15 s2    (-0.00245336355614, 3653.42015515, -0.17060101...

And test saved to CSV looks like this:

LinregressResult(slope=-0.0019029525667976811, intercept=2217.6728473825792, rvalue=-0.19038162624636565, pvalue=4.2750387135904842e-07, stderr=0.00037275167083276965)
LinregressResult(slope=-0.0072414295730094738, intercept=4213.7157961188113, rvalue=-0.30608491681348643, pvalue=4.8781453623746113e-17, stderr=0.00084171437048465665)
LinregressResult(slope=-0.0024533635561369252, intercept=3653.4201551461483, rvalue=-0.17060101350197393, pvalue=1.4676330869804576e-07, stderr=0.0004631573671617427)

But the CSV output I want is:

year  make    engine  drive  transmission  badge                   slope             intercept       rvalue      
1994  subaru  1.6L    awd    auto          wrx                     -0.0019029525668  2217.67284738 -0.190381626...
1997  mazda   1.3L    2wd    manual        121 metro               -0.00724142957301 4213.71579612 -0.30608491...
1999  nissan  1.6L    2wd    auto          pulsar plus lx n15 s2   -0.00245336355614 3653.42015515 -0.17060101...

so that I can easily look up the results later. How can I attach the LinregressResult fields to each group and save them to CSV?

2 answers:

Answer 0 (score: 2)

I guess you could do it this way:

test = (grouped.apply(lambda x: pd.Series(linregress(x['km'], x['price'])))
               .rename(columns={
                        0: 'slope',
                        1: 'intercept',
                        2: 'rvalue',
                        3: 'pvalue',
                        4: 'stderr'
                      })
       )

instead of

test = grouped.apply(lambda x: linregress(x['km'], x['price']))
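As a minimal end-to-end sketch of the same idea, the snippet below runs the Series-wrapping approach on a tiny made-up dataset. The data values, the two-column grouping, and the names FIELDS and summary are assumptions for illustration only; they are not from the question's CSV. Selecting the km and price columns before apply is a choice made here so that the function only sees the columns it needs.

```python
import pandas as pd
from scipy.stats import linregress

# Made-up stand-in for the question's CSV (illustrative values only).
df = pd.DataFrame({
    'year':  [1994] * 4 + [1997] * 4,
    'make':  ['subaru'] * 4 + ['mazda'] * 4,
    'km':    [20000, 60000, 100000, 140000] * 2,
    'price': [2200, 2080, 2050, 1930, 4100, 3750, 3500, 3200],
})

# The five values linregress returns when treated as a tuple, in order.
FIELDS = ['slope', 'intercept', 'rvalue', 'pvalue', 'stderr']

summary = (
    df.groupby(['year', 'make'])[['km', 'price']]
      # Wrapping the tuple-like result in a Series gives one column per field,
      # labelled 0..4.
      .apply(lambda g: pd.Series(list(linregress(g['km'], g['price']))))
      # Map the positional column labels 0..4 to readable names.
      .rename(columns=dict(enumerate(FIELDS)))
)

print(summary)
```

The group keys end up in the index of summary, with one numeric column per regression field.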

Demo:

import numpy as np
import pandas as pd

rows = 10

# generate random integer numbers
df = pd.DataFrame(np.random.randint(0, 10, size=(rows, 5)), columns=list('abcde'))

def linregress(x):
    # imitates `linregress`
    # returns tuples 
    return tuple(x)

test = (df.apply(lambda x: pd.Series(linregress(x)), axis=1)
          .rename(columns={
                   0: 'slope',
                   1: 'intercept',
                   2: 'rvalue',
                   3: 'pvalue',
                   4: 'stderr'
                 })
       )

Output:

In [48]: df.apply(lambda x: linregress(x), axis=1)
Out[48]:
0    (7, 7, 2, 0, 0)
1    (6, 9, 3, 1, 5)
2    (5, 1, 6, 1, 3)
3    (4, 4, 2, 1, 4)
4    (8, 7, 1, 5, 4)
5    (0, 2, 7, 6, 1)
6    (3, 8, 4, 2, 8)
7    (6, 0, 0, 3, 2)
8    (9, 4, 6, 2, 3)
9    (8, 1, 7, 9, 8)
dtype: object


In [50]: test = (df.apply(lambda x: pd.Series(linregress(x)), axis=1)
   ....:           .rename(columns={
   ....:                    0: 'slope',
   ....:                    1: 'intercept',
   ....:                    2: 'rvalue',
   ....:                    3: 'pvalue',
   ....:                    4: 'stderr'
   ....:                  })
   ....:        )

In [51]: test
Out[51]:
   slope  intercept  rvalue  pvalue  stderr
0      7          7       2       0       0
1      6          9       3       1       5
2      5          1       6       1       3
3      4          4       2       1       4
4      8          7       1       5       4
5      0          2       7       6       1
6      3          8       4       2       8
7      6          0       0       3       2
8      9          4       6       2       3
9      8          1       7       9       8

Answer 1 (score: 0)

Solution

Define a function that returns the fields you need as a labelled pd.Series:

def extract_lr(x):
    lr = linregress(x['km'], x['price'])
    return pd.Series([lr.slope, lr.intercept, lr.rvalue],
                     index=['slope', 'intercept', 'rvalue'])

and use this function in the apply call:

test = grouped.apply(extract_lr)
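To get the CSV layout asked for in the question, the group keys also have to be written out. With the question's to_csv(..., index=False) they would be dropped, because after apply they live in the index. A possible sketch, assuming a small made-up dataset and writing to an in-memory buffer instead of grouped.csv (the names flat, buf, and csv_text are illustrative):

```python
import io

import pandas as pd
from scipy.stats import linregress

# Made-up data: each group lies exactly on a straight line.
df = pd.DataFrame({
    'year':  [1994, 1994, 1994, 1997, 1997, 1997],
    'make':  ['subaru'] * 3 + ['mazda'] * 3,
    'km':    [10000, 20000, 30000, 10000, 20000, 30000],
    'price': [2000.0, 1900.0, 1800.0, 4000.0, 3800.0, 3600.0],
})

def extract_lr(x):
    # Fit price against km for one group and keep three named fields.
    lr = linregress(x['km'], x['price'])
    return pd.Series({'slope': lr.slope,
                      'intercept': lr.intercept,
                      'rvalue': lr.rvalue})

test = df.groupby(['year', 'make'])[['km', 'price']].apply(extract_lr)

# Move the group keys from the index into ordinary columns, then write
# without the (now meaningless) integer index.
flat = test.reset_index()
buf = io.StringIO()
flat.to_csv(buf, index=False)
csv_text = buf.getvalue()
print(csv_text)
```

An alternative would be to skip reset_index() and call to_csv without index=False, which should also write the group keys as leading columns.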