我正在将KerasRegressor
与scikit-learn的MLPRegressor
进行比较,前者似乎在类似条件下表现更差。或者我错过了一些非常基本的东西?
输出如下:
MLPRegressor
============
r2_score: 0.9498015043626485
mean_absolute_error: 1.2375251067546698
mean_squared_error: 3.1270623572484233
median_absolute_error: 0.8439709890654576
mean_absolute_percentage_error: 5.716115241892469
KerasRegressor
==============
r2_score: -0.21815826719358755
mean_absolute_error: 6.447014662587636
mean_squared_error: 75.8838848484866
median_absolute_error: 4.708965668576953
mean_absolute_percentage_error: 27.069446468540352
此外,我收到了几个警告:
/home/bacalfa/.local/lib/python3.6/site-packages/sklearn/base.py:122: DeprecationWarning: Estimator KerasRegressor modifies parameters in __init__. This behavior is deprecated as of 0.18 and support for this behavior will be removed in 0.20.
相关模块的版本:
Keras==2.0.0
numpy==1.12.0
scipy==0.19.0
sklearn==0.18.1
tensorflow==1.0.1
tensorflow-gpu==1.0.1
Theano==0.9.0rc4
修改
按照指示here编辑文件~/.local/lib/python3.6/site-packages/keras/wrappers/scikit_learn.py
后,我的结果仍然不好。但是,在从mae
选项列表中删除mape
,msle
和loss
后(即,我只考虑mse
),我终于得到了更好的结果。我还添加了一些时序变量并更新了Gist文件。不确定为什么这些损失选项会导致问题......
MLPRegressor
============
Time: 63.069955586000106
Score: 0.9743374210376049
Parameters: {'solver': 'lbfgs', 'nesterovs_momentum': True, 'momentum': 0.95582961730267013, 'learning_rate_init': 6.9585889749848416e-05, 'learning_rate': 'invscaling', 'hidden_layer_sizes': (13,), 'beta_2': 0.54401214785115037, 'beta_1': 0.49106283542759821, 'alpha': 1.7176826485217918, 'activation': 'logistic'}
r2_score: 0.9498015043626485
mean_absolute_error: 1.2375251067546698
mean_squared_error: 3.1270623572484233
median_absolute_error: 0.8439709890654576
mean_absolute_percentage_error: 5.716115241892469
KerasRegressor
==============
Time: 111.47958513100002
Score: -7.324164294960475
Parameters: {'units': 9, 'optimizer': 'adamax', 'nesterov': False, 'momentum': 0.3903057105612151, 'lr': 0.012612536273812315, 'loss': 'mse', 'kernel_regularizer_weight': 0.049702772055240763, 'kernel_regularizer': 'l1_l2', 'kernel_initializer': 'normal', 'beta_2': 0.31764966762573782, 'beta_1': 0.055241532426840559, 'activation': 'relu'}
r2_score: 0.9353675562984648
mean_absolute_error: 1.4353763767225414
mean_squared_error: 4.026209932985799
median_absolute_error: 1.1217291129910958
mean_absolute_percentage_error: 6.831308588574378