所以我必须基于11个输入来建立回归模型来预测葡萄酒质量。目前,我正在评估各种算法的均方误差,均值绝对误差和R2分数。我想决定要使用哪种算法,但是在我想做之前,我想确保我的数据不被过度拟合/拟合不足。下面是我使用的数据集的链接(虽然有点不同,但数据完全相同)以及我的整个代码。
任何帮助将不胜感激!
数据: https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/
此外,我在kagggle链接中从以下位置复制了大部分代码: https://www.kaggle.com/jhansia/regression-models-analysis-on-the-wine-quality
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
wine = pd.read_csv('wineQualityReds.csv', usecols=lambda x: 'Unnamed' not in x,)
wine.head()
y = wine.quality
X = wine.drop('quality',axis = 1)
from sklearn.model_selection import train_test_split
train_x,test_x,train_y,test_y = train_test_split(X,y,random_state = 0, stratify = y)
from sklearn import preprocessing
scaler = preprocessing.StandardScaler().fit(train_x)
train_x_scaled = scaler.transform(train_x)
test_x_scaled = scaler.transform(test_x)
from sklearn import model_selection
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.metrics import mean_absolute_error
models = []
models.append(('DecisionTree', DecisionTreeRegressor()))
models.append(('RandomForest', RandomForestRegressor()))
models.append(('GradienBoost', GradientBoostingRegressor()))
models.append(('SVR', SVR()))
names = []
for name,model in models:
kfold = model_selection.KFold(n_splits=5,random_state=2)
cv_results = model_selection.cross_val_score(model,train_x_scaled,train_y, cv= kfold, scoring = 'neg_mean_absolute_error')
names.append(name)
msg = "%s: %f" % (name, -1*(cv_results).mean())
print(msg)
model = RandomForestRegressor()
model.fit(train_x_scaled,train_y)
pred_y = model.predict(test_x_scaled)
from sklearn import metrics
print('Mean Squared Error:', metrics.mean_squared_error(test_y, pred_y))
print('Mean Absolute Error:', metrics.mean_absolute_error(test_y, pred_y))
print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(test_y, pred_y)))
print('R2:', metrics.r2_score(test_y, pred_y))
答案 0 :(得分:0)
您可以对数据集进行交叉验证,以发现数据是否过度拟合。