如何使用sklearn从决策树模型提高预测的准确性?

时间:2019-12-02 21:58:29

标签: python pandas machine-learning scikit-learn decision-tree

我使用sklearn在Python中创建了一个决策树模型,该模型从大型公共数据集中获取数据,这些数据将人为因素(年龄,bmi,性别,吸烟等)与保险公司支付的医疗费用相关联年。我以0.2的测试大小拆分了数据集,但是平均绝对误差和均方误差都非常高。我尝试进行不同的分割(.5,.8),但没有得到任何不同的结果。预测模型在某些方面似乎还不完善,但我不确定缺少哪些部分以及需要改进的地方。我已经附上了我的输出照片(通过IMGUR链接,因为我无法添加照片)以及我的代码,我感谢正确方向的任何指导!

https://imgur.com/a/6D74uB0

dataset = pd.read_csv('insurance.csv')

LE = LabelEncoder()
LE.fit(dataset.sex.drop_duplicates())
dataset.sex = LE.transform(dataset.sex)
LE.fit(dataset.smoker.drop_duplicates())
dataset.smoker = LE.transform(dataset.smoker)
LE.fit(dataset.region.drop_duplicates())
dataset.region = LE.transform(dataset.region)

print("Data Head")
print(dataset.head())
print()
print("Data Info")
print(dataset.info())
print()



for i in dataset.columns:
    print('Null Values in {i} :'.format(i = i) , dataset[i].isnull().sum())


X = dataset.drop('charges', axis = 1) 
y = dataset['charges'] 


X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2, random_state=0)  

regressor = DecisionTreeRegressor()  
regressor.fit(X_train, y_train)  

y_pred = regressor.predict(X_test) 

df = pd.DataFrame({'Actual Value': y_test, 'Predicted Values': y_pred})  
print(df)

print('Mean Absolute Error:', metrics.mean_absolute_error(y_test, y_pred))
print('Mean Squared Error:', metrics.mean_squared_error(y_test, y_pred))
print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_test, y_pred)))

3 个答案:

答案 0 :(得分:1)

如果您还没有做的话可以做的某些事情:

  1. 在非分类列/功能上使用scikit-learn中的StandardScaler()
  2. 使用scikit-learn中的GridSearchCV来搜索适当的超参数,而不是手动进行。尽管选择手动进行操作可能会让您对哪些参数值可能有效有所了解。
  3. 仔细检查DecisionTreeRegressor的文档,以确保您的实现与文档一致。

看看是否有帮助。

答案 1 :(得分:0)

您可以使用xgboost,这是一种增强算法。

答案 2 :(得分:0)

引导聚集(https://en.wikipedia.org/wiki/Bootstrap_aggregating)是减少估计量方差的一种简便方法。如果您已经在使用sklearn回归器,则几乎不需要其他代码。下面是一个示例,说明了如何使用简单的袋装估算器来减少模型的方差:

from pymodbus.client.sync import ModbusSerialClient as ModbusClient
modbus = ModbusClient(method='rtu', port='/dev/tty.usbserial-AQ00BYCR', baudrate=9600, timeout=1)
modbus.connect()
test = modbus.read_holding_registers(1, 1, unit=1)
print (test)