@pandas_udf - linear regression

Time: 2021-01-29 18:59:49

Tags: python pandas apache-spark pyspark

Can someone help me?

I need to run a linear regression for each group defined by two columns (group_id and sex).

Example:

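import pandas as pd
import statsmodels.api as sm
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, DoubleType, StringType

pdf = pd.DataFrame({'group_id':[1,1,1,2,2,2,3,3,3,3],
                    'sex':['M','M','F','F','M','F','M','F','F','M'],
                    'x':[0,1,2,0,1,5,2,3,4,5],
                    'y':[2,1,0,0,0.5,2.5,3,4,5,6]})
# sqlContext is provided by the environment (e.g. a notebook)
df = sqlContext.createDataFrame(pdf)

# schema of the DataFrame returned for each group
result_schema = StructType([
    StructField('group_id', DoubleType()),
    StructField('sex', StringType()),
    StructField('x', DoubleType())
])

@pandas_udf(result_schema, PandasUDFType.GROUPED_MAP)
def ols(df):
    # all rows in a group share the same keys
    group_id = df['group_id'].iloc[0]
    sex = df['sex'].iloc[0]
    y = df['y'].astype(int)  # note: truncates 0.5 and 2.5 to 0 and 2
    X = df['x'].astype(int)
    X = sm.add_constant(X)  # intercept column 'const'
    model = sm.OLS(y, X).fit()
    # model.params is expected to be [const, x]; [1] picks the slope
    return pd.DataFrame([[group_id, sex, model.params[1]]],
                        columns=['group_id', 'sex', 'x'])

beta = df.groupby('group_id', 'sex').apply(ols)
beta.show()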

It returns this error:


PythonException: An exception was thrown from a UDF: 'IndexError: index out of bounds',
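The groups that trigger this error can be found without Spark by counting rows per (group_id, sex) pair in the local pandas frame. A plausible mechanism, if I read statsmodels' `add_constant` correctly: its default `has_constant='skip'` treats a nonzero column with zero range as an existing constant, and with a single row every column looks like that. No intercept is then added, `model.params` has only one entry, and `params[1]` is out of bounds. The helper `single_row_groups` below is a hypothetical name used only for illustration.

```python
import pandas as pd

# Same sample data as in the question
pdf = pd.DataFrame({'group_id': [1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
                    'sex': ['M', 'M', 'F', 'F', 'M', 'F', 'M', 'F', 'F', 'M'],
                    'x': [0, 1, 2, 0, 1, 5, 2, 3, 4, 5],
                    'y': [2, 1, 0, 0, 0.5, 2.5, 3, 4, 5, 6]})


def single_row_groups(frame: pd.DataFrame) -> list:
    """Hypothetical helper: (group_id, sex) keys of groups with exactly one row."""
    sizes = frame.groupby(['group_id', 'sex']).size()
    # int() keeps the keys as plain Python ints regardless of pandas/NumPy version
    return sorted((int(g), s) for g, s in sizes[sizes == 1].index)


print(pdf.groupby(['group_id', 'sex']).size())
print(single_row_groups(pdf))
```

These are exactly the groups for which the UDF cannot estimate a slope.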

2 answers:

Answer 0 (score: 2)

Look at the data:

import pandas as pd

df = pd.DataFrame({'group_id':[1,1,1,2,2,2,3,3,3,3],
                    'sex':['M','M','F','F','M','F','M','F','F','M'],
                    'x':[0,1,2,0,1,5,2,3,4,5],
                    'y':[2,1,0,0,0.5,2.5,3,4,5,6]})

and simply look at the groups:

for name, sdf in df.groupby(['group_id', 'sex']):
    print(name)
    print(sdf)

We get:

(1, 'F')
   group_id sex  x    y
2         1   F  2  0.0
(1, 'M')
   group_id sex  x    y
0         1   M  0  2.0
1         1   M  1  1.0
(2, 'F')
   group_id sex  x    y
3         2   F  0  0.0
5         2   F  5  2.5
(2, 'M')
   group_id sex  x    y
4         2   M  1  0.5
(3, 'F')
   group_id sex  x    y
7         3   F  3  4.0
8         3   F  4  5.0
(3, 'M')
   group_id sex  x    y
6         3   M  2  3.0
9         3   M  5  6.0

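Two points always give a linear (perfect) fit; you need at least three points before the data can deviate from a perfect fit at all. Some of these groups have only a single data point, which means you cannot fit a line to them...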
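A quick numerical check of the two-point claim, assuming NumPy and pandas are available: for every two-row group, the least-squares line passes through both points, so the residual sum of squares is essentially zero and the residual degrees of freedom (n - 2) are zero. `np.polyfit` stands in for `sm.OLS` here, y is kept as float, and `fit_line` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

pdf = pd.DataFrame({'group_id': [1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
                    'sex': ['M', 'M', 'F', 'F', 'M', 'F', 'M', 'F', 'F', 'M'],
                    'x': [0, 1, 2, 0, 1, 5, 2, 3, 4, 5],
                    'y': [2, 1, 0, 0, 0.5, 2.5, 3, 4, 5, 6]})


def fit_line(x, y):
    """Least-squares fit of y = intercept + slope * x.

    Returns (slope, intercept, residual sum of squares, residual degrees of freedom).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    slope, intercept = np.polyfit(x, y, deg=1)
    residuals = y - (intercept + slope * x)
    return slope, intercept, float(np.sum(residuals ** 2)), len(x) - 2


# Only the two-row groups: each fit is exact, leaving nothing to estimate error from
for key, g in pdf.groupby(['group_id', 'sex']):
    if len(g) == 2:
        slope, intercept, rss, dof = fit_line(g['x'], g['y'])
        print(key, round(slope, 6), rss, dof)
```

The one-row groups are skipped on purpose: a single point does not determine a line at all.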

Answer 1 (score: 0)

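Some groups have only one point, so a linear regression cannot be fitted to them. To handle this inside the function, check the number of rows in the DataFrame and return null if there is only one row: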

@pandas_udf(result_schema, PandasUDFType.GROUPED_MAP)
def ols(df):
    group_id = df['group_id'].iloc[0]
    sex = df['sex'].iloc[0]

    # a single point cannot define a line: return null for the slope
    if len(df) == 1:
        return pd.DataFrame([[group_id, sex, None]], columns=['group_id', 'sex', 'x'])

    # note: astype(int) truncates y (2.5 -> 2), which is why group (2, F) gets 0.4 below instead of 0.5
    y = df['y'].astype(int)
    X = df['x'].astype(int)
    X = sm.add_constant(X)
    model = sm.OLS(y, X).fit()
    # params is [const, x]; [1] picks the slope
    return pd.DataFrame([[group_id, sex, model.params[1]]], columns=['group_id', 'sex', 'x'])


df.groupby('group_id', 'sex').apply(ols).show()
+--------+---+-------------------+
|group_id|sex|                  x|
+--------+---+-------------------+
|     2.0|  M|               null|
|     3.0|  F|                1.0|
|     1.0|  M|-1.0000000000000002|
|     1.0|  F|               null|
|     2.0|  F|0.39999999999999986|
|     3.0|  M| 0.9999999999999998|
+--------+---+-------------------+
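A Spark-free sketch of the same guard, assuming only pandas and NumPy, can be handy for testing the per-group logic locally. It returns NaN when a group has fewer than two distinct x values (which covers the one-row case), and it keeps y as float, so group (2, F) gets a slope of 0.5 rather than the 0.4 that `astype(int)` produces in the output above. The closed-form slope cov(x, y) / var(x), which equals the OLS slope of a model with an intercept, stands in for `sm.OLS`; `ols_slope` is a hypothetical name.

```python
import numpy as np
import pandas as pd

pdf = pd.DataFrame({'group_id': [1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
                    'sex': ['M', 'M', 'F', 'F', 'M', 'F', 'M', 'F', 'F', 'M'],
                    'x': [0, 1, 2, 0, 1, 5, 2, 3, 4, 5],
                    'y': [2, 1, 0, 0, 0.5, 2.5, 3, 4, 5, 6]})


def ols_slope(group: pd.DataFrame) -> pd.DataFrame:
    """Slope of y on x for one (group_id, sex) group; NaN if it cannot be estimated."""
    x = group['x'].astype(float)
    y = group['y'].astype(float)  # keep fractional values such as 0.5 and 2.5
    if x.nunique() < 2:
        slope = np.nan  # one distinct x value defines no line
    else:
        slope = x.cov(y) / x.var()  # OLS slope when an intercept is included
    return pd.DataFrame({'group_id': [float(group['group_id'].iloc[0])],
                         'sex': [group['sex'].iloc[0]],
                         'x': [slope]})


# Iterating the groupby passes each full sub-frame, grouping columns included
result = pd.concat([ols_slope(g) for _, g in pdf.groupby(['group_id', 'sex'])],
                   ignore_index=True)
print(result)
```

On Spark 3.0 and later, where `applyInPandas` is the recommended replacement for `GroupedData.apply` with a `GROUPED_MAP` pandas UDF, the same plain function could be passed as `df.groupby('group_id', 'sex').applyInPandas(ols_slope, schema=result_schema)`.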