我试图在不使用sklearn库的情况下构建回归树。 这就是我试图完成任务的方式。
df_train = pd.read_csv('user/assignment2/data/housing_price_train.csv')
df = df_train[['OverallQual','GrLivArea','GarageArea','TotalBsmtSF','1stFlrSF','FullBath','TotRmsAbvGrd','YearBuilt','SalePrice']]
st_sale = np.std(df.SalePrice)
over, less = [],[]
std_r = {}
for c in median:
over = df.SalePrice.loc[df[c] > df[c].median()]
less = df.SalePrice.loc[df[c] <= df[c].median()]
std_over = np.std(over)
std_less = np.std(less)
p_over = over.count()/len(df)
p_less = less.count()/len(df)
result = p_over*std_over + p_less*std_less
result = st_sale - result
std_r[c] = result
del std_r['SalePrice']
std_r = sorted(std_r.items(), key=lambda x: x[1],reverse=True)
std_r
test = pd.read_csv('user/assignment2/data/housing_price_test.csv')
test= test[['OverallQual','GrLivArea','GarageArea','TotalBsmtSF','1stFlrSF','FullBath', 'TotRmsAbvGrd', 'YearBuilt']]
def predict_price(row):
df_part = df
for split_cond in std_r:
col = split_cond[0]
if row[col] < np.median(df_part[col]):
if not df_part.empty:
df_part = df_part.loc[ df_part[col]<np.median(df_part[col]) ]
else:
if not df_part.empty:
df_part = df_part.loc[ df_part[col]>=np.median(df_part[col]) ]
return np.mean(df_part['SalePrice'])
def predict_all(test_df):
ids = []
predictions = []
for i in range(len(test)):
row = test.iloc[i,:]
predicted_price = predict_price(row)
ids.append(i)
predictions.append(predicted_price)
return ids, predictions
def createSubmission(test_ids, predictions):
sub = pd.DataFrame()
sub['Id'] = test_ids
sub['SalePrice'] = predictions
sub.to_csv('user/assignment2/submission.csv',index=False)
def main():
ids, predictions = predict_all(test)
createSubmission(ids, predictions)
if __name__ == '__main__':
main()
std_r的输出为[('OverallQual',24169.639457317156), ('GrLivArea',18442.815198341486),('YearBuilt',14193.29356392668), ('GarageArea',13759.676934338233),('1stFlrSF',12427.210763854717), ('TotalBsmtSF',12327.533408498653), ('TotRmsAbvGrd',11091.35232613662),('FullBath',5381.883704447857)]
基于此,我试图降低df并最终取剩余销售价格的均值。这给我最终的CSV文件中的一些空值,这没有发生