对于训练和测试数据,sklearn Random Forest准确度得分相同

时间:2017-05-14 18:31:10

标签: python machine-learning scikit-learn random-forest metrics

我试图建立电动汽车充电事件数据的分类模型。我想预测充电站是否可以在给定的时间点使用。我有以下代码:

from sklearn.ensemble import RandomForestClassifier
import pandas as pd

raw_data = pd.read_csv('C:/temp/sample_dataset.csv')
raw_test = pd.read_csv('C:/temp/sample_dataset_test.csv')
print ('raw data shape: ', raw_test.shape)

#choose which columns to dummify
X_vars = ['station_id', 'day_of_week', 'epoch', 'station_city',
 'station_county', 'station_zip', 'port_level', 'perc_local_occupancy',
 'ports_at_station', 'avg_charge_duration']
y_var = ['target_variable']
categorical_vars = ['station_id','station_city','station_county']

#split X and y in training and test
X_train = raw_data.loc[:,X_vars]
y_train = raw_data.loc[:,y_var]
X_test = raw_test.loc[:,X_vars]
y_test = raw_test.loc[:,y_var]

#make dummy variables
X_train = pd.get_dummies(X_train, columns = categorical_vars )
X_test = pd.get_dummies(X_test, columns=categorical_vars)

print('train size', X_train.shape, '\ntest size', X_test.shape)

# Train uncalibrated random forest classifier on whole train and evaluate on test data
clf = RandomForestClassifier(n_estimators=100, max_depth=2)
clf.fit(X_train, y_train.values.ravel())

print ('RF accuracy: TRAINING', clf.score(X_train,y_train))
print ('RF accuracy: TESTING', clf.score(X_test,y_test))

结果

raw data shape:  (1000000, 15)
train size (1000000, 125) 
test size (1000000, 125)
RF accuracy: TRAINING 0.831456
RF accuracy: TESTING 0.831456

我的问题是为什么培训和测试准确度完全相同?我经常这么多次,它总是完全一样。有任何想法吗? (我已检查确保原始数据不同)

2 个答案:

答案 0 :(得分:1)

您的代码中只有一个拼写错误,因为每次选择所有行:

#split X and y in training and test
X_train = raw_data.loc[:,X_vars] 
y_train = raw_data.loc[:,y_var]
X_test = raw_test.loc[:,X_vars]
y_test = raw_test.loc[:,y_var]

您应该通过某些索引单独索引它们,例如:X_train = raw_data.loc[:idx,X_vars]

答案 1 :(得分:0)

您是否可能在列车和测试文件中使用相同的数据集?

如果是相同的数据,那么最好将数据拆分为train并使用train_test_split模块进行测试。

http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html