Python Dataframe: RMSE score differs on every run of Random Forest Regressor

Date: 2017-12-19 22:35:44

Tags: python scikit-learn random-forest modeling seed

I am currently running a random forest model with the code below. I set random_state to 100.

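# sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split lives in model_selection
from sklearn.model_selection import train_test_split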

X_train_RIA_INST_PWM, X_test_RIA_INST_PWM, y_train_RIA_INST_PWM, y_test_RIA_INST_PWM = train_test_split(X_RIA_INST_PWM, Y_RIA_INST_PWM, test_size=0.3, random_state = 100)



# Random Forest Regressor for RIA_INST_PWM accounts  

import numpy as np
from sklearn.ensemble import RandomForestRegressor

regressor_RIA_INST_PWM = RandomForestRegressor(n_estimators=100, min_samples_split = 10)
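# Fit on the training split only, so the test-set R^2 below is a genuine held-out score
regressor_RIA_INST_PWM.fit(X_train_RIA_INST_PWM, y_train_RIA_INST_PWM)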

print("R^2 for training set:", regressor_RIA_INST_PWM.score(X_train_RIA_INST_PWM, y_train_RIA_INST_PWM))

print('-' * 50)

print("R^2 for test set:", regressor_RIA_INST_PWM.score(X_test_RIA_INST_PWM, y_test_RIA_INST_PWM))

I then compute the predicted values with the following code.

def predict_AUM(df, features, regressor):

    # Reset index for later merge of predicted target values with Account IDs
    # (reset_index returns a new DataFrame, so the result must be assigned)
    df = df.reset_index()

    # Set predictor variables 
    X_Predict = df[features]

    # Clean inputs 
    X_Predict = X_Predict.replace([np.inf, -np.inf], np.nan)
    X_Predict = X_Predict.fillna(0)

    # Predict Current_AUM
    Y_AUM_Snapshot_1yr_Predict = regressor.predict(X_Predict)
    df['PREDICTED_SPAN'] = Y_AUM_Snapshot_1yr_Predict

    return df 

df_EVENT5_20 = predict_AUM(df_EVENT5_19, dfzip_features_AUM_RIA_INST_PWM, regressor_RIA_INST_PWM)

Finally, I compute the RMSE of the results:

from sklearn.metrics import mean_squared_error
from math import sqrt

rmse = sqrt(mean_squared_error(df_EVENT5_20['SPAN_DAYS'], df_EVENT5_20['PREDICTED_SPAN']))
rmse

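Every time I run my code, my RMSE changes. It ranges from 7.75 to 16.4. Why does this happen? How can I get the same RMSE every time I run the code? Also, how can I optimize my model for RMSE?

1 Answer:

Answer 0: (score: 0)

You only seeded train_test_split, which ensures that the random assignment of the data to the training and test sets is reproducible.

As the name suggests, RandomForestRegressor also contains parts of the algorithm that depend on random numbers (for example, the different subsets of the data or the different features used to train the individual decision trees). If you want reproducible results, you need to seed it as well. To do that, initialize it with random_state: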

regressor_RIA_INST_PWM = RandomForestRegressor(n_estimators=100, min_samples_split = 10, random_state = 100)
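
To see the effect of seeding in isolation, here is a minimal sketch on synthetic data. The arrays X and y and the helper fit_and_score are made up for illustration and are not part of the question's code. It fits the same forest twice with random_state=100 and twice with random_state=None, using the same fixed train/test split each time.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (an assumption; not the asker's dataset)
rng = np.random.RandomState(0)
X = rng.normal(size=(300, 5))
y = 3.0 * X[:, 0] - 2.0 * X[:, 1] + rng.normal(scale=0.5, size=300)

# The split is seeded, so every model below sees identical train/test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=100
)


def fit_and_score(seed):
    """Fit a forest with the given random_state and return the test RMSE."""
    model = RandomForestRegressor(
        n_estimators=100, min_samples_split=10, random_state=seed
    )
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    return float(np.sqrt(mean_squared_error(y_test, predictions)))


# Same seed twice: the forests are built identically
rmse_seeded_a = fit_and_score(100)
rmse_seeded_b = fit_and_score(100)

# No seed: each fit draws fresh randomness from NumPy's global state
rmse_unseeded_a = fit_and_score(None)
rmse_unseeded_b = fit_and_score(None)

print("seeded:  ", rmse_seeded_a, rmse_seeded_b)
print("unseeded:", rmse_unseeded_a, rmse_unseeded_b)
```

The two seeded RMSE values should be identical, while the unseeded ones will typically differ slightly from run to run even though the split is fixed. On this small, well-behaved dataset the spread is likely much narrower than the 7.75 to 16.4 range reported in the question.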