提高 Python 中机器学习模型预测的准确性

时间:2021-01-03 18:38:03

标签: python machine-learning regression

我们目前正在 Python 中为一家本地公司实施 ML 模型,以预测 0-999 分范围内的信用评分。从数据库中提取了 11 个自变量(信用记录和支付行为)和一个因变量(信用评分)。客户已声明生产模型的 MAE 必须小于 100 分才能有用。问题是我们已经尝试了几种算法来实现这种回归,但我们的模型无法很好地概括不可见的数据。到目前为止,性能最好的算法似乎是随机森林,但它在测试数据上的 MAE 仍然超出了可接受的值。这是我们的代码:

import numpy as np
from sklearn.preprocessing import RobustScaler
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Ridge
from sklearn.linear_model import Lasso
from sklearn.linear_model import ElasticNet
from sklearn.ensemble import RandomForestRegressor
from sklearn import metrics
from keras.layers import Dense
from keras.models import Sequential

# Linear Model
def GetLinearModel(X, y):
    model = LinearRegression()
    model.fit(X, y)
    return model   

# Ridge Regression
def GetRidge(X, y):
    model = Ridge(alpha=0.01)
    model.fit(X_train, y_train) 
    return model

# LASSO Regression
def GetLASSO(X, y):
    model = Lasso(alpha=0.01)
    model.fit(X_train, y_train) 
    return model

# ElasticNet Regression
def GetElasticNet(X, y):
    model = ElasticNet(alpha=0.01)
    model.fit(X_train, y_train) 
    return model

# Random Forest
def GetRandomForest(X, y):
    model = RandomForestRegressor(n_estimators=32, random_state=0)
    model.fit(X, y)
    return model

# Neural Networks
def GetNeuralNetworks(X, y):
    model = Sequential()
    model.add(Dense(32, activation = 'relu', input_dim = 11))
    model.add(Dense(units = 32, activation = 'relu'))
    model.add(Dense(units = 32, activation = 'relu'))
    model.add(Dense(units = 32, activation = 'relu'))
    model.add(Dense(units = 32, activation = 'relu'))
    model.add(Dense(units = 1))
    model.compile(optimizer = 'adam', loss = 'mean_absolute_error')
    model.fit(X, y, batch_size = 100, epochs = 500, verbose=0)
    return model

# Train data
train_set = np.array([\
[2, 5, 9, 28, 0, 0.153668, 500, 0, 0, 0.076923077, 0, 800],\
[3, 0, 0, 42, 2, 0.358913, 500, 0, 0, 0.230769231, 0, 900],\
[3, 0, 0, 12, 2, 0, 500, 0, 0, 0.076923077, 0, 500],\
[1, 0, 0, 6, 1, 0.340075, 457, 0, 0, 0.076923077, 0, 560],\
[1, 5, 0, 12, 3, 0.458358, 457, 0, 0, 0.153846154, 0, 500],\
[1, 3, 4, 32, 2, 0.460336, 457, 0, 0, 0.153846154, 0, 600],\
[3, 0, 0, 42, 4, 0.473414, 500, 0, 0, 0.230769231, 0, 700],\
[1, 3, 0, 16, 0, 0.332991, 500, 0, 0, 0.076923077, 0, 600],\
[1, 3, 19, 27, 0, 0.3477, 500, 0, 0, 0.076923077, 0, 580],\
[1, 5, 20, 74, 1, 0.52076, 500, 0, 0, 0.230769231, 0, 550],\
[6, 0, 0, 9, 3, 0, 500, 0, 0, 0.076923077, 0, 570],\
[1, 8, 47, 0, 0, 0.840656, 681, 0, 0, 0, 0, 50],\
[1, 0, 0, 8, 14, 0, 681, 0, 0, 0.076923077, 0, 400],\
[5, 6, 19, 7, 1, 0.251423, 500, 0, 1, 0.076923077, 1, 980],\
[1, 0, 0, 2, 2, 0.121852, 500, 1, 0, 0.076923077, 9, 780],\
[2, 0, 0, 4, 0, 0.37242, 500, 1, 0, 0.076923077, 0, 920],\
[3, 4, 5, 20, 0, 0.37682, 500, 1, 0, 0.076923077, 0, 700],\
[3, 8, 17, 20, 0, 0.449545, 500, 1, 0, 0.076923077, 0, 300],\
[3, 12, 30, 20, 0, 0.551193, 500, 1, 0, 0.076923077, 0, 30],\
[0, 1, 10, 8, 3, 0.044175, 500, 0, 0, 0.076923077, 0, 350],\
[1, 0, 0, 14, 3, 0.521714, 500, 0, 0, 0.153846154, 0, 650],\
[2, 4, 15, 0, 0, 0.985122, 500, 0, 0, 0, 0, 550],\
[2, 4, 34, 0, 0, 0.666666, 500, 0, 0, 0, 0, 600],\
[1, 16, 17, 10, 3, 0.299756, 330, 0, 0, 0.153846154, 0, 650],\
[2, 0, 0, 16, 1, 0, 500, 0, 0, 0.076923077, 0, 900],\
[2, 5, 31, 26, 0, 0.104847, 500, 0, 0, 0.076923077, 0, 850],\
[2, 6, 16, 34, 1, 0.172947, 500, 0, 0, 0.153846154, 0, 900],\
[1, 4, 0, 16, 6, 0.206403, 500, 0, 0, 0.153846154, 0, 630],\
[1, 8, 20, 12, 5, 0.495897, 500, 0, 0, 0.153846154, 0, 500],\
[1, 8, 46, 8, 6, 0.495897, 500, 0, 0, 0.153846154, 0, 250],\
[2, 0, 0, 4, 8, 0, 500, 0, 0, 0.076923077, 0, 550],\
[2, 6, 602, 0, 0, 0, 500, 0, 0, 0, 0, 20],\
[0, 12, 5, 21, 0, 0.158674, 645, 0, 0, 0.153846154, 0, 850],\
[0, 12, 20, 21, 0, 0.158674, 645, 0, 0, 0.153846154, 0, 700],\
[1, 0, 0, 33, 0, 0.041473, 645, 0, 0, 0.230769231, 0, 890],\
[1, 0, 0, 12, 2, 0.147325, 500, 0, 0, 0.076923077, 0, 780],\
[1, 8, 296, 0, 0, 2.891695, 521, 0, 0, 0, 0, 1],\
[1, 0, 0, 4, 0, 0.098953, 445, 0, 0, 0.076923077, 0, 600],\
[1, 0, 0, 4, 0, 0.143443, 500, 0, 0, 0.076923077, 0, 500],\
[0, 8, 20, 0, 0, 1.110002, 833, 0, 0, 0, 0, 100],\
[0, 0, 0, 8, 2, 0, 833, 0, 0, 0.076923077, 0, 300],\
[1, 4, 60, 20, 6, 0.78685, 833, 0, 0, 0.153846154, 0, 100],\
[1, 4, 112, 20, 6, 0.78685, 833, 0, 0, 0.153846154, 0, 1],\
[1, 0, 0, 21, 10, 0.305556, 500, 0, 0, 0.307692308, 0, 150],\
[1, 0, 0, 21, 10, 0.453743, 500, 0, 0, 0.307692308, 0, 300],\
[0, 0, 0, 8, 0, 0, 570, 0, 0, 0, 0, 500],\
[0, 10, 10, 8, 0, 0.325975, 570, 0, 0, 0.076923077, 0, 450],\
[1, 7, 16, 15, 1, 0.266311, 570, 0, 0, 0.076923077, 0, 450],\
[1, 1, 32, 30, 4, 0.134606, 570, 0, 0, 0.230769231, 0, 250],\
[1, 0, 0, 32, 5, 0.105576, 570, 0, 0, 0.230769231, 0, 430],\
[1, 4, 34, 32, 5, 0.519103, 500, 0, 0, 0.230769231, 0, 350],\
[1, 0, 0, 12, 1, 0.109559, 669, 0, 0, 0.076923077, 0, 600],\
[11, 4, 15, 2, 3, 0.235709, 500, 0, 1, 0, 2, 900],\
[11, 4, 15, 1, 6, 0.504134, 500, 0, 1, 0, 2, 534],\
[2, 0, 0, 15, 9, 0.075403, 500, 0, 0, 0.076923077, 0, 573],\
[10, 0, 0, 51, 11, 2.211951, 500, 0, 0, 0.307692308, 7, 547],\
[9, 0, 0, 28, 4, 0.328037, 500, 0, 0, 0.230769231, 0, 747],\
[9, 2, 0, 0, 0, 0.166666, 500, 0, 1, 0.076923077, 4, 448],\
[8, 0, 0, 4, 1, 0, 500, 0, 1, 0, 1, 719],\
[3, 4, 15, 8, 1, 0.150237, 500, 0, 1, 0, 0, 827],\
[7, 138, 35, 37, 1, 0.414154, 500, 0, 1, 0.076923077, 3, 950],\
[6, 19, 41, 84, 1, 0.41248, 500, 0, 0, 0.230769231, 0, 750],\
[1, 6, 10, 0, 0, 0.232647, 500, 0, 1, 0, 0, 700],\
[0, 10, 27, 0, 0, 0.411712, 4, 0, 0, 0, 0, 520],\
[3, 31, 45, 80, 0, 0.266299, 500, 0, 0, 0.153846154, 0, 750],\
[3, 24, 49, 2, 1, 0.981102, 500, 0, 0, 0.076923077, 0, 550],\
[1, 12, 31, 11, 1, 0.333551, 500, 0, 0, 0.153846154, 0, 500],\
[0, 18, 30, 13, 2, 0.602826, 406, 0, 0, 0.076923077, 0, 580],\
[2, 2, 31, 0, 0, 1, 500, 0, 0, 0, 0, 427],\
[1, 18, 40, 83, 1, 0.332792, 500, 0, 0, 0.307692308, 0, 485],\
[2, 14, 35, 9, 3, 0.39671, 500, 0, 1, 0.076923077, 3, 664],\
[2, 88, 32, 7, 2, 0.548066, 500, 0, 1, 0, 1, 90],\
[2, 26, 26, 32, 2, 0.415991, 500, 0, 0, 0.153846154, 0, 90],\
[1, 14, 30, 11, 1, 0.51743, 599, 0, 0, 0.153846154, 0, 300],\
[1, 15, 28, 26, 0, 0.4413, 500, 0, 0, 0.076923077, 0, 610],\
[1, 17, 50, 34, 1, 0.313789, 500, 0, 0, 0.230769231, 0, 450],\
[0, 4, 15, 0, 0, 0.535163, 500, 0, 0, 0, 0, 375],\
[0, 8, 23, 0, 0, 0.51242, 500, 0, 0, 0, 0, 550],\
[3, 6, 44, 2, 3, 0.268062, 500, 0, 1, 0, 2, 744],\
[6, 38, 51, 35, 0, 0.28396, 500, 0, 1, 0.076923077, 1, 980],\
[6, 5, 63, 6, 5, 0.566661, 500, 0, 0, 0.153846154, 0, 850],\
[6, 0, 0, 0, 0, 0.174852, 500, 0, 0, 0, 0, 800],\
[6, 4, 60, 6, 3, 0.517482, 500, 0, 0, 0.076923077, 0, 750],\
[5, 16, 52, 49, 4, 0.378441, 500, 0, 1, 0.153846154, 6, 720],\
[5, 26, 84, 103, 1, 0.472361, 500, 0, 0, 0.230769231, 0, 300],\
[1, 6, 34, 36, 1, 0.298553, 500, 0, 1, 0.153846154, 0, 628],\
[5, 6, 65, 34, 0, 0.301907, 500, 0, 0, 0.153846154, 0, 710],\
[3, 16, 177, 29, 10, 0.501831, 500, 1, 0, 0.153846154, 0, 40],\
[2, 5, 45, 0, 0, 0.351668, 500, 0, 0, 0, 0, 708],\
[2, 7, 57, 7, 4, 0.432374, 500, 0, 0, 0.153846154, 0, 753],\
[1, 1, 75, 36, 0, 0.154085, 500, 0, 0, 0.076923077, 0, 610],\
[1, 16, 63, 13, 2, 0.331244, 500, 0, 0, 0.076923077, 0, 620],\
[1, 3, 55, 9, 0, 0.377253, 500, 0, 0, 0.076923077, 0, 640],\
[1, 1, 75, 5, 5, 0.877696, 500, 0, 0, 0.076923077, 0, 480],\
[1, 0, 0, 8, 5, 0.208742, 500, 0, 0, 0.153846154, 0, 520],\
[1, 3, 55, 29, 0, 0.228812, 678, 0, 0, 0.153846154, 0, 547],\
[1, 0, 0, 2, 2, 0.090459, 553, 0, 0, 0.076923077, 0, 535],\
[0, 4, 29, 0, 0, 0.292161, 500, 0, 0, 0, 0, 594],\
[1, 3, 64, 18, 6, 0.602431, 500, 0, 0, 0.230769231, 0, 500],\
[6, 9, 40, 74, 0, 0.567179, 500, 0, 0, 0.076923077, 0, 910],\
[4, 10, 65, 14, 1, 0.423915, 500, 0, 1, 0, 1, 713],\
[2, 0, 0, 6, 1, 0.114637, 500, 0, 0, 0.076923077, 0, 650],\
[5, 18, 74, 34, 0, 0.489314, 500, 0, 0, 0.153846154, 0, 500],\
[0, 6, 43, 9, 15, 0.599918, 612, 0, 0, 0.153846154, 0, 100],\
[4, 25, 64, 135, 0, 0.472659, 500, 0, 0, 0.230769231, 0, 560],\
[6, 3, 94, 12, 10, 0.31713, 500, 0, 0, 0.230769231, 0, 580],\
[1, 4, 69, 18, 9, 0.412528, 500, 0, 0, 0.307692308, 0, 362],\
[2, 21, 58, 21, 0, 0.53184, 500, 0, 0, 0.153846154, 0, 370],\
[0, 0, 0, 21, 4, 0.033438, 500, 0, 0, 0.153846154, 0, 500],\
[0, 10, 53, 20, 0, 0.619595, 500, 0, 0, 0.076923077, 0, 200],\
[2, 15, 63, 28, 2, 0.593453, 500, 0, 0, 0.153846154, 0, 574],\
[3, 2, 84, 21, 1, 0.302636, 500, 0, 0, 0.153846154, 0, 790],\
[4, 19, 47, 28, 0, 0.256892, 500, 0, 0, 0.076923077, 0, 748],\
[1, 0, 0, 0, 0, 0.119599, 500, 0, 0, 0, 0, 517],\
[3, 10, 53, 22, 0, 0.419703, 500, 0, 0, 0.153846154, 0, 800],\
[4, 7, 66, 70, 1, 0.362268, 500, 0, 0, 0.230769231, 0, 550],\
[0, 16, 88, 18, 3, 0.597145, 16, 0, 0, 0.153846154, 0, 50],\
[5, 8, 38, 0, 0, 0.666666, 500, 0, 0, 0, 0, 667]])

# Test data    
test_set = np.array([\
[2, 16, 87, 30, 0, 0.168057, 500, 0, 1, 0.153846154, 1, 760],\
[3, 5, 83, 6, 4, 0.273522, 500, 0, 0, 0.076923077, 0, 877],\
[1, 0, 0, 12, 0, 0.262797, 500, 0, 0, 0.153846154, 0, 596],\
[2, 15, 46, 28, 0, 0.495495, 500, 0, 0, 0.076923077, 0, 680],\
[1, 0, 0, 22, 9, 0.254813, 500, 0, 0, 0.230769231, 0, 450],\
[3, 19, 59, 12, 0, 0.437851, 500, 0, 0, 0.153846154, 0, 850],\
[4, 5, 28, 0, 0, 0.34559, 500, 0, 1, 0.076923077, 1, 800],\
[1, 5, 58, 0, 0, 0.385379, 500, 0, 0, 0, 0, 641],\
[1, 4, 65, 15, 1, 0.2945, 500, 0, 0, 0.153846154, 0, 644],\
[0, 0, 0, 9, 3, 0.421612, 500, 0, 0, 0.076923077, 0, 580],\
[3, 31, 83, 2, 2, 0.436883, 500, 0, 0, 0.076923077, 0, 410],\
[0, 0, 0, 18, 5, 0.044898, 377, 0, 0, 0.230769231, 0, 520],\
[0, 8, 49, 12, 3, 0.428529, 500, 0, 1, 0.076923077, 1, 370],\
[0, 22, 89, 2, 1, 0.819431, 500, 0, 0, 0.076923077, 0, 440],\
[3, 27, 63, 124, 0, 0.375306, 500, 0, 0, 0.076923077, 0, 880],\
[3, 20, 64, 18, 5, 0.439412, 500, 0, 1, 0.076923077, 3, 820],\
[1, 6, 34, 2, 12, 0.495654, 500, 0, 0, 0.076923077, 0, 653],\
[0, 14, 225, 0, 0, 1, 486, 0, 0, 0, 0, 1],\
[2, 8, 87, 32, 1, 0.829792, 500, 0, 0, 0.230769231, 0, 570],\
[2, 15, 46, 24, 4, 0.500442, 500, 0, 0, 0.153846154, 0, 568]])

# split datasets into independent and dependent variables
X_train, y_train = train_set[:, :-1], train_set[:, -1]    
X_test, y_test = test_set[:, :-1], test_set[:, -1]    

# feature scaling
sc = RobustScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.fit_transform(X_test)

# Linear model
reg = GetLinearModel(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("Linear", mae))

# Ridge Regression
reg = GetRidge(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("Ridge", mae))

# LASSO Regression
reg = GetLASSO(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("LASSO", mae))

# ElasticNet Regression
reg = GetElasticNet(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("ElasticNet", mae))

# Random Forest
reg = GetRandomForest(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("Random Forest", mae))

# Neural networks
reg = GetNeuralNetworks(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("Neural Networks", mae))

输出:

         Linear: 141.265089
          Ridge: 141.267797
          LASSO: 141.274700
     ElasticNet: 141.413544
  Random Forest: 102.701562
WARNING:tensorflow:11 out of the last 11 calls to <function Model.make_predict_function.<locals>.predict_function at 0x00000229766694C0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Neural Networks: 122.301840

任何有关如何提高模型准确性的帮助将不胜感激。

亲切的问候。

4 个答案:

答案 0 :(得分:1)

我正在使用您在示例中提供的数据集 我还创建了训练、验证和测试数据集,以避免@Prayson W. Daniel 提到的数据泄漏

对于神经网络,您需要确保标签和特征都被缩放。您可以选择标准标量。您还需要确保特征和标签必须为 2 暗。在您的示例中,您的标签是一维数组。

使用以下代码提取二维特征

Train_labels=train_set[:,[-1]]

你可以使用StandardScaler对数据进行归一化,你需要确保标签和特征都需要归一化

现在,一旦您构建了 ANN,您就需要确保您的网络能够看到大量数据 由于您的训练和测试非常少,您可以使用 K 折交叉验证 我现在不使用 k 折叠,但我正在创建模型

from keras import regularizers
def build_model() :
    Model=K.models.Sequential()
    Model.add(K.layers.Dense(units=21,activation='relu',
              kernel_regularizer=regularizers.l2(0.001),input_dim=11))
    Model.add(K.layers.Dropout(0.2))
    Model.add(K.layers.Dense(21,activation='relu',
              kernel_regularizer=regularizers.l2(0.001)))
    Model.add(K.layers.Dropout(0.2))
    Model.add(K.layers.Dense(21,activation='relu'))
    Model.add(K.layers.Dense(1))

    #Compile the model


    Optimizer=K.optimizers.Nadam()
    Model.compile(optimizer=Optimizer,loss='mae',metrics=r2_keras_custom)
    return Model


model=build_model()
history=model.fit(x=X_train,y=Y_train,epochs=200,batch_size=29,validation_data= 
(X_test,Y_test))

I am using R2 as custom metric,you can also create one 

这里我使用的是 1-RSS/TSS 的 r2

plt.plot(history.history['val_r2_keras_custom'])
plt.plot(history.history['r2_keras_custom'])
plt.legend(['Test_score','Train_score'])
plt.plot()

enter image description here

Final score

希望对你有帮助,其他人可以纠正我

答案 1 :(得分:0)

如果那是整个数据集,那就太小了。要考虑的一种选择是研究交叉验证,而不是将数据拆分为训练和验证(AKA 测试)。交叉验证是一种适用于小数据集的方法,其中所有数据都用于训练和验证,但仍可防止过度拟合。

答案 2 :(得分:0)

您可以为每个模型执行超参数调整和交叉验证。

这个类可以帮助你做到这一点:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

GridSearchCV 也兼容 Keras 模型。为此,您可以查看: https://machinelearningmastery.com/grid-search-hyperparameters-deep-learning-models-python-keras/

答案 3 :(得分:0)

就个人而言,训练数据集中的少量记录意味着机器学习算法训练集合中的基分类器数量较少。检查您的代码,我之前没有使用过 RobustScaler,但我会在测试数据集上使用转换,而不是 fit_transform。

回到您的代码,看起来随机森林的准确度最高。通过超调一些参数,包括估计器的数量和 max_depth,可以报告更好的性能。此后,正如其他答案/评论所推荐的那样,此处需要对算法参数进行超调。

# -*- coding: utf-8 -*-
"""
Created on Wed Jan  6 20:50:44 2021

@author: AliHaidar
"""

import numpy as np
from sklearn.preprocessing import RobustScaler
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Ridge
from sklearn.linear_model import Lasso
from sklearn.linear_model import ElasticNet
from sklearn.ensemble import RandomForestRegressor,GradientBoostingRegressor,AdaBoostRegressor
from sklearn import metrics

from xgboost import XGBRegressor


# Linear Model
def GetLinearModel(X, y):
    model = LinearRegression()
    model.fit(X, y)
    return model   

# Ridge Regression
def GetRidge(X, y):
    model = Ridge(alpha=0.01)
    model.fit(X_train, y_train) 
    return model

# LASSO Regression
def GetLASSO(X, y):
    model = Lasso(alpha=0.01)
    model.fit(X_train, y_train) 
    return model

# ElasticNet Regression
def GetElasticNet(X, y):
    model = ElasticNet(alpha=0.01)
    model.fit(X_train, y_train) 
    return model

# Random Forest
def GetRandomForest(X, y):
    model = RandomForestRegressor(n_estimators=4, random_state=0,max_depth=11)
    model.fit(X, y)
    return model


# Train data
train_set = np.array([\
[2, 5, 9, 28, 0, 0.153668, 500, 0, 0, 0.076923077, 0, 800],\
[3, 0, 0, 42, 2, 0.358913, 500, 0, 0, 0.230769231, 0, 900],\
[3, 0, 0, 12, 2, 0, 500, 0, 0, 0.076923077, 0, 500],\
[1, 0, 0, 6, 1, 0.340075, 457, 0, 0, 0.076923077, 0, 560],\
[1, 5, 0, 12, 3, 0.458358, 457, 0, 0, 0.153846154, 0, 500],\
[1, 3, 4, 32, 2, 0.460336, 457, 0, 0, 0.153846154, 0, 600],\
[3, 0, 0, 42, 4, 0.473414, 500, 0, 0, 0.230769231, 0, 700],\
[1, 3, 0, 16, 0, 0.332991, 500, 0, 0, 0.076923077, 0, 600],\
[1, 3, 19, 27, 0, 0.3477, 500, 0, 0, 0.076923077, 0, 580],\
[1, 5, 20, 74, 1, 0.52076, 500, 0, 0, 0.230769231, 0, 550],\
[6, 0, 0, 9, 3, 0, 500, 0, 0, 0.076923077, 0, 570],\
[1, 8, 47, 0, 0, 0.840656, 681, 0, 0, 0, 0, 50],\
[1, 0, 0, 8, 14, 0, 681, 0, 0, 0.076923077, 0, 400],\
[5, 6, 19, 7, 1, 0.251423, 500, 0, 1, 0.076923077, 1, 980],\
[1, 0, 0, 2, 2, 0.121852, 500, 1, 0, 0.076923077, 9, 780],\
[2, 0, 0, 4, 0, 0.37242, 500, 1, 0, 0.076923077, 0, 920],\
[3, 4, 5, 20, 0, 0.37682, 500, 1, 0, 0.076923077, 0, 700],\
[3, 8, 17, 20, 0, 0.449545, 500, 1, 0, 0.076923077, 0, 300],\
[3, 12, 30, 20, 0, 0.551193, 500, 1, 0, 0.076923077, 0, 30],\
[0, 1, 10, 8, 3, 0.044175, 500, 0, 0, 0.076923077, 0, 350],\
[1, 0, 0, 14, 3, 0.521714, 500, 0, 0, 0.153846154, 0, 650],\
[2, 4, 15, 0, 0, 0.985122, 500, 0, 0, 0, 0, 550],\
[2, 4, 34, 0, 0, 0.666666, 500, 0, 0, 0, 0, 600],\
[1, 16, 17, 10, 3, 0.299756, 330, 0, 0, 0.153846154, 0, 650],\
[2, 0, 0, 16, 1, 0, 500, 0, 0, 0.076923077, 0, 900],\
[2, 5, 31, 26, 0, 0.104847, 500, 0, 0, 0.076923077, 0, 850],\
[2, 6, 16, 34, 1, 0.172947, 500, 0, 0, 0.153846154, 0, 900],\
[1, 4, 0, 16, 6, 0.206403, 500, 0, 0, 0.153846154, 0, 630],\
[1, 8, 20, 12, 5, 0.495897, 500, 0, 0, 0.153846154, 0, 500],\
[1, 8, 46, 8, 6, 0.495897, 500, 0, 0, 0.153846154, 0, 250],\
[2, 0, 0, 4, 8, 0, 500, 0, 0, 0.076923077, 0, 550],\
[2, 6, 602, 0, 0, 0, 500, 0, 0, 0, 0, 20],\
[0, 12, 5, 21, 0, 0.158674, 645, 0, 0, 0.153846154, 0, 850],\
[0, 12, 20, 21, 0, 0.158674, 645, 0, 0, 0.153846154, 0, 700],\
[1, 0, 0, 33, 0, 0.041473, 645, 0, 0, 0.230769231, 0, 890],\
[1, 0, 0, 12, 2, 0.147325, 500, 0, 0, 0.076923077, 0, 780],\
[1, 8, 296, 0, 0, 2.891695, 521, 0, 0, 0, 0, 1],\
[1, 0, 0, 4, 0, 0.098953, 445, 0, 0, 0.076923077, 0, 600],\
[1, 0, 0, 4, 0, 0.143443, 500, 0, 0, 0.076923077, 0, 500],\
[0, 8, 20, 0, 0, 1.110002, 833, 0, 0, 0, 0, 100],\
[0, 0, 0, 8, 2, 0, 833, 0, 0, 0.076923077, 0, 300],\
[1, 4, 60, 20, 6, 0.78685, 833, 0, 0, 0.153846154, 0, 100],\
[1, 4, 112, 20, 6, 0.78685, 833, 0, 0, 0.153846154, 0, 1],\
[1, 0, 0, 21, 10, 0.305556, 500, 0, 0, 0.307692308, 0, 150],\
[1, 0, 0, 21, 10, 0.453743, 500, 0, 0, 0.307692308, 0, 300],\
[0, 0, 0, 8, 0, 0, 570, 0, 0, 0, 0, 500],\
[0, 10, 10, 8, 0, 0.325975, 570, 0, 0, 0.076923077, 0, 450],\
[1, 7, 16, 15, 1, 0.266311, 570, 0, 0, 0.076923077, 0, 450],\
[1, 1, 32, 30, 4, 0.134606, 570, 0, 0, 0.230769231, 0, 250],\
[1, 0, 0, 32, 5, 0.105576, 570, 0, 0, 0.230769231, 0, 430],\
[1, 4, 34, 32, 5, 0.519103, 500, 0, 0, 0.230769231, 0, 350],\
[1, 0, 0, 12, 1, 0.109559, 669, 0, 0, 0.076923077, 0, 600],\
[11, 4, 15, 2, 3, 0.235709, 500, 0, 1, 0, 2, 900],\
[11, 4, 15, 1, 6, 0.504134, 500, 0, 1, 0, 2, 534],\
[2, 0, 0, 15, 9, 0.075403, 500, 0, 0, 0.076923077, 0, 573],\
[10, 0, 0, 51, 11, 2.211951, 500, 0, 0, 0.307692308, 7, 547],\
[9, 0, 0, 28, 4, 0.328037, 500, 0, 0, 0.230769231, 0, 747],\
[9, 2, 0, 0, 0, 0.166666, 500, 0, 1, 0.076923077, 4, 448],\
[8, 0, 0, 4, 1, 0, 500, 0, 1, 0, 1, 719],\
[3, 4, 15, 8, 1, 0.150237, 500, 0, 1, 0, 0, 827],\
[7, 138, 35, 37, 1, 0.414154, 500, 0, 1, 0.076923077, 3, 950],\
[6, 19, 41, 84, 1, 0.41248, 500, 0, 0, 0.230769231, 0, 750],\
[1, 6, 10, 0, 0, 0.232647, 500, 0, 1, 0, 0, 700],\
[0, 10, 27, 0, 0, 0.411712, 4, 0, 0, 0, 0, 520],\
[3, 31, 45, 80, 0, 0.266299, 500, 0, 0, 0.153846154, 0, 750],\
[3, 24, 49, 2, 1, 0.981102, 500, 0, 0, 0.076923077, 0, 550],\
[1, 12, 31, 11, 1, 0.333551, 500, 0, 0, 0.153846154, 0, 500],\
[0, 18, 30, 13, 2, 0.602826, 406, 0, 0, 0.076923077, 0, 580],\
[2, 2, 31, 0, 0, 1, 500, 0, 0, 0, 0, 427],\
[1, 18, 40, 83, 1, 0.332792, 500, 0, 0, 0.307692308, 0, 485],\
[2, 14, 35, 9, 3, 0.39671, 500, 0, 1, 0.076923077, 3, 664],\
[2, 88, 32, 7, 2, 0.548066, 500, 0, 1, 0, 1, 90],\
[2, 26, 26, 32, 2, 0.415991, 500, 0, 0, 0.153846154, 0, 90],\
[1, 14, 30, 11, 1, 0.51743, 599, 0, 0, 0.153846154, 0, 300],\
[1, 15, 28, 26, 0, 0.4413, 500, 0, 0, 0.076923077, 0, 610],\
[1, 17, 50, 34, 1, 0.313789, 500, 0, 0, 0.230769231, 0, 450],\
[0, 4, 15, 0, 0, 0.535163, 500, 0, 0, 0, 0, 375],\
[0, 8, 23, 0, 0, 0.51242, 500, 0, 0, 0, 0, 550],\
[3, 6, 44, 2, 3, 0.268062, 500, 0, 1, 0, 2, 744],\
[6, 38, 51, 35, 0, 0.28396, 500, 0, 1, 0.076923077, 1, 980],\
[6, 5, 63, 6, 5, 0.566661, 500, 0, 0, 0.153846154, 0, 850],\
[6, 0, 0, 0, 0, 0.174852, 500, 0, 0, 0, 0, 800],\
[6, 4, 60, 6, 3, 0.517482, 500, 0, 0, 0.076923077, 0, 750],\
[5, 16, 52, 49, 4, 0.378441, 500, 0, 1, 0.153846154, 6, 720],\
[5, 26, 84, 103, 1, 0.472361, 500, 0, 0, 0.230769231, 0, 300],\
[1, 6, 34, 36, 1, 0.298553, 500, 0, 1, 0.153846154, 0, 628],\
[5, 6, 65, 34, 0, 0.301907, 500, 0, 0, 0.153846154, 0, 710],\
[3, 16, 177, 29, 10, 0.501831, 500, 1, 0, 0.153846154, 0, 40],\
[2, 5, 45, 0, 0, 0.351668, 500, 0, 0, 0, 0, 708],\
[2, 7, 57, 7, 4, 0.432374, 500, 0, 0, 0.153846154, 0, 753],\
[1, 1, 75, 36, 0, 0.154085, 500, 0, 0, 0.076923077, 0, 610],\
[1, 16, 63, 13, 2, 0.331244, 500, 0, 0, 0.076923077, 0, 620],\
[1, 3, 55, 9, 0, 0.377253, 500, 0, 0, 0.076923077, 0, 640],\
[1, 1, 75, 5, 5, 0.877696, 500, 0, 0, 0.076923077, 0, 480],\
[1, 0, 0, 8, 5, 0.208742, 500, 0, 0, 0.153846154, 0, 520],\
[1, 3, 55, 29, 0, 0.228812, 678, 0, 0, 0.153846154, 0, 547],\
[1, 0, 0, 2, 2, 0.090459, 553, 0, 0, 0.076923077, 0, 535],\
[0, 4, 29, 0, 0, 0.292161, 500, 0, 0, 0, 0, 594],\
[1, 3, 64, 18, 6, 0.602431, 500, 0, 0, 0.230769231, 0, 500],\
[6, 9, 40, 74, 0, 0.567179, 500, 0, 0, 0.076923077, 0, 910],\
[4, 10, 65, 14, 1, 0.423915, 500, 0, 1, 0, 1, 713],\
[2, 0, 0, 6, 1, 0.114637, 500, 0, 0, 0.076923077, 0, 650],\
[5, 18, 74, 34, 0, 0.489314, 500, 0, 0, 0.153846154, 0, 500],\
[0, 6, 43, 9, 15, 0.599918, 612, 0, 0, 0.153846154, 0, 100],\
[4, 25, 64, 135, 0, 0.472659, 500, 0, 0, 0.230769231, 0, 560],\
[6, 3, 94, 12, 10, 0.31713, 500, 0, 0, 0.230769231, 0, 580],\
[1, 4, 69, 18, 9, 0.412528, 500, 0, 0, 0.307692308, 0, 362],\
[2, 21, 58, 21, 0, 0.53184, 500, 0, 0, 0.153846154, 0, 370],\
[0, 0, 0, 21, 4, 0.033438, 500, 0, 0, 0.153846154, 0, 500],\
[0, 10, 53, 20, 0, 0.619595, 500, 0, 0, 0.076923077, 0, 200],\
[2, 15, 63, 28, 2, 0.593453, 500, 0, 0, 0.153846154, 0, 574],\
[3, 2, 84, 21, 1, 0.302636, 500, 0, 0, 0.153846154, 0, 790],\
[4, 19, 47, 28, 0, 0.256892, 500, 0, 0, 0.076923077, 0, 748],\
[1, 0, 0, 0, 0, 0.119599, 500, 0, 0, 0, 0, 517],\
[3, 10, 53, 22, 0, 0.419703, 500, 0, 0, 0.153846154, 0, 800],\
[4, 7, 66, 70, 1, 0.362268, 500, 0, 0, 0.230769231, 0, 550],\
[0, 16, 88, 18, 3, 0.597145, 16, 0, 0, 0.153846154, 0, 50],\
[5, 8, 38, 0, 0, 0.666666, 500, 0, 0, 0, 0, 667]])

# Test data    
test_set = np.array([\
[2, 16, 87, 30, 0, 0.168057, 500, 0, 1, 0.153846154, 1, 760],\
[3, 5, 83, 6, 4, 0.273522, 500, 0, 0, 0.076923077, 0, 877],\
[1, 0, 0, 12, 0, 0.262797, 500, 0, 0, 0.153846154, 0, 596],\
[2, 15, 46, 28, 0, 0.495495, 500, 0, 0, 0.076923077, 0, 680],\
[1, 0, 0, 22, 9, 0.254813, 500, 0, 0, 0.230769231, 0, 450],\
[3, 19, 59, 12, 0, 0.437851, 500, 0, 0, 0.153846154, 0, 850],\
[4, 5, 28, 0, 0, 0.34559, 500, 0, 1, 0.076923077, 1, 800],\
[1, 5, 58, 0, 0, 0.385379, 500, 0, 0, 0, 0, 641],\
[1, 4, 65, 15, 1, 0.2945, 500, 0, 0, 0.153846154, 0, 644],\
[0, 0, 0, 9, 3, 0.421612, 500, 0, 0, 0.076923077, 0, 580],\
[3, 31, 83, 2, 2, 0.436883, 500, 0, 0, 0.076923077, 0, 410],\
[0, 0, 0, 18, 5, 0.044898, 377, 0, 0, 0.230769231, 0, 520],\
[0, 8, 49, 12, 3, 0.428529, 500, 0, 1, 0.076923077, 1, 370],\
[0, 22, 89, 2, 1, 0.819431, 500, 0, 0, 0.076923077, 0, 440],\
[3, 27, 63, 124, 0, 0.375306, 500, 0, 0, 0.076923077, 0, 880],\
[3, 20, 64, 18, 5, 0.439412, 500, 0, 1, 0.076923077, 3, 820],\
[1, 6, 34, 2, 12, 0.495654, 500, 0, 0, 0.076923077, 0, 653],\
[0, 14, 225, 0, 0, 1, 486, 0, 0, 0, 0, 1],\
[2, 8, 87, 32, 1, 0.829792, 500, 0, 0, 0.230769231, 0, 570],\
[2, 15, 46, 24, 4, 0.500442, 500, 0, 0, 0.153846154, 0, 568]])

# split datasets into independent and dependent variables
X_train, y_train = train_set[:, :-1], train_set[:, -1]    
X_test, y_test = test_set[:, :-1], test_set[:, -1]    

# feature scaling
sc = RobustScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.fit_transform(X_test)

# Linear model
reg = GetLinearModel(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("Linear", mae))

# Ridge Regression
reg = GetRidge(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("Ridge", mae))

# LASSO Regression
reg = GetLASSO(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("LASSO", mae))

# ElasticNet Regression
reg = GetElasticNet(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("ElasticNet", mae))

# Random Forest
reg = GetRandomForest(X_train, y_train)
y_pred = reg.predict(X_test)
mae = metrics.mean_absolute_error(y_test, y_pred)
print("%15s: %10f" % ("Random Forest", mae))


输出:

         Linear: 141.265089
          Ridge: 141.267797
          LASSO: 141.274700
     ElasticNet: 141.413544
  Random Forest:  90.776332