Can anyone help me?
I need to run a linear regression for each group defined by two columns.
Example:
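import pandas as pd
import statsmodels.api as sm
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, DoubleType, StringType
pdf = pd.DataFrame({'group_id':[1,1,1,2,2,2,3,3,3,3],
                    'sex':['M','M','F','F','M','F','M','F','F','M'],
                    'x':[0,1,2,0,1,5,2,3,4,5],
                    'y':[2,1,0,0,0.5,2.5,3,4,5,6]})
df = sqlContext.createDataFrame(pdf)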
result_schema = StructType([
    StructField('group_id', DoubleType()),
    StructField('sex', StringType()),
    StructField('x', DoubleType())
])
@pandas_udf(result_schema, PandasUDFType.GROUPED_MAP)
def ols(df):
    group_id = df['group_id'].iloc[0]
    sex = df['sex'].iloc[0]
    y = df['y'].astype(int)
    X = df['x'].astype(int)
    X = sm.add_constant(X)
    model = sm.OLS(y, X).fit()
    return pd.DataFrame([[group_id] + [sex] + [model.params[1]]], columns=['group_id'] + ['sex'] + ['x'])
beta = df.groupby('group_id', 'sex').apply(ols)
beta.show()
It returns an error:
PythonException: An exception was thrown from a UDF: 'IndexError: index out of bounds'
Answer 0: (score: 2)
Look at the data
df = pd.DataFrame({'group_id':[1,1,1,2,2,2,3,3,3,3],
'sex':['M','M','F','F','M','F','M','F','F','M'],
'x':[0,1,2,0,1,5,2,3,4,5],
'y':[2,1,0,0,0.5,2.5,3,4,5,6]})
and simply look at the groups
for name, sdf in df.groupby(['group_id', 'sex']):
    print(name)
    print(sdf)
We get
(1, 'F')
   group_id sex  x    y
2         1   F  2  0.0
(1, 'M')
   group_id sex  x    y
0         1   M  0  2.0
1         1   M  1  1.0
(2, 'F')
   group_id sex  x    y
3         2   F  0  0.0
5         2   F  5  2.5
(2, 'M')
   group_id sex  x    y
4         2   M  1  0.5
(3, 'F')
   group_id sex  x    y
7         3   F  3  4.0
8         3   F  4  5.0
(3, 'M')
   group_id sex  x    y
6         3   M  2  3.0
9         3   M  5  6.0
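Two points always give a perfect linear fit; you need at least three points before the data can deviate from a perfect fit at all. Worse, some of these groups contain only a single data point, which means you cannot fit a line to them...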
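The group sizes can also be checked directly instead of reading them off the printed groups. The following is a minimal sketch in plain pandas (no Spark); the names `sizes`, `too_small`, and `fittable` are hypothetical and not part of the original post.

```python
import pandas as pd

df = pd.DataFrame({'group_id': [1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
                   'sex': ['M', 'M', 'F', 'F', 'M', 'F', 'M', 'F', 'F', 'M'],
                   'x': [0, 1, 2, 0, 1, 5, 2, 3, 4, 5],
                   'y': [2, 1, 0, 0, 0.5, 2.5, 3, 4, 5, 6]})

# Number of rows in each (group_id, sex) group
sizes = df.groupby(['group_id', 'sex']).size()

# Groups with fewer than two points cannot determine a slope
too_small = sizes[sizes < 2]
print(too_small)

# Keep only the rows that belong to groups with at least two points
fittable = df.groupby(['group_id', 'sex']).filter(lambda g: len(g) >= 2)
print(fittable)
```

Filtering beforehand is one option; returning a null slope for the small groups is another, depending on whether those groups should appear in the result.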
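Answer 1: (score: 0)
Some groups have only one point, so a linear regression cannot be fitted to them. To catch this inside the function, check the number of rows in the DataFrame and return null for the slope if there is only one row.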
@pandas_udf(result_schema, PandasUDFType.GROUPED_MAP)
def ols(df):
    group_id = df['group_id'].iloc[0]
    sex = df['sex'].iloc[0]
    if len(df) == 1:
        return pd.DataFrame([[group_id] + [sex] + [None]], columns=['group_id'] + ['sex'] + ['x'])
    else:
        y = df['y'].astype(int)
        X = df['x'].astype(int)
        X = sm.add_constant(X)
        model = sm.OLS(y, X).fit()
        return pd.DataFrame([[group_id] + [sex] + [model.params[1]]], columns=['group_id'] + ['sex'] + ['x'])
df.groupby('group_id', 'sex').apply(ols).show()
+--------+---+-------------------+
|group_id|sex|                  x|
+--------+---+-------------------+
|     2.0|  M|               null|
|     3.0|  F|                1.0|
|     1.0|  M|-1.0000000000000002|
|     1.0|  F|               null|
|     2.0|  F|0.39999999999999986|
|     3.0|  M| 0.9999999999999998|
+--------+---+-------------------+
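The same guard can be tried locally without a Spark session. The sketch below is a stand-in under stated assumptions, not the answer's method: it uses plain pandas and `np.polyfit` in place of the grouped-map UDF and `sm.OLS`, and the `slopes` dictionary is a hypothetical name. It keeps the `astype(int)` cast on `y` from the code above, which truncates 0.5 and 2.5 to 0 and 2.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'group_id': [1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
                   'sex': ['M', 'M', 'F', 'F', 'M', 'F', 'M', 'F', 'F', 'M'],
                   'x': [0, 1, 2, 0, 1, 5, 2, 3, 4, 5],
                   'y': [2, 1, 0, 0, 0.5, 2.5, 3, 4, 5, 6]})

slopes = {}
for key, g in df.groupby(['group_id', 'sex']):
    if len(g) < 2:
        # A single point does not determine a line: record a missing slope
        slopes[key] = None
    else:
        x = g['x'].astype(int).to_numpy()
        y = g['y'].astype(int).to_numpy()  # same cast as above; truncates fractions
        # Degree-1 least-squares fit; the first coefficient is the slope
        slopes[key] = float(np.polyfit(x, y, 1)[0])

for key, slope in slopes.items():
    print(key, slope)
```

Checking `len(g) < 2` rather than `len(g) == 1` also covers an empty group, should one ever be passed in.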