如何从历史记录访​​问指标

时间:2019-10-24 13:23:00

标签: python keras neural-network

对于回归问题,我想比较一些指标,但是我只能从历史中获得accuracy,这对回归目的毫无意义。如何获得其他指标,例如mean_squared_error等?

create_model(...)
    input_layer = ...
    output_laye = ...
    model = Model(input_layer, output_layer)
    model.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])
    return model

model = KerasRegressor(build_fn=create_model, verbose=0)

batch_size = [1, 2]
epochs = [1, 2]
optimizer = ['Adam', 'sgd']    
param_grid = dict(batch_size=batch_size
                     , optimizer = optimizer
                     )

grid_obj  = RandomizedSearchCV(estimator=model 
                    , param_grid=hypparas
                    , n_jobs=1
                    , cv = 3
                    , scoring = ['explained_variance', 'neg_mean_squared_error', 'r2']
                    , refit = 'neg_mean_squared_error'
                    , return_train_score=True
                    , verbose = 2
                    )

grid_result = grid_obj.fit(X_train1, y_train1)

X_train1, X_val1, y_train1, y_val1 = train_test_split(X_train1, y_train1, test_size=0.2, shuffle=False)

grid_best = grid_result.best_estimator_
history = grid_best.fit(X_train1, y_train1
                        , validation_data=(X_val1, y_val1)
                        )

print(history.history.keys())
> dict_keys(['val_loss', 'val_accuracy', 'loss', 'accuracy'])

我见过https://stackoverflow.com/a/50137577/6761328,例如

history.history['accuracy']

可以,但是我无法访问mean_squared_error或其他内容:

history.history['neg_mean_squared_error']
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-473-eb96973bf014> in <module>
----> 1 history.history['neg_mean_squared_error']

KeyError: 'neg_mean_squared_error'

这个问题终于是对How to compare different metrics?的跟进,因为我认为这个问题是对另一个问题的解答。

1 个答案:

答案 0 :(得分:2)

在独立的Keras(不确定scikit-learn包装器)中,history.history['loss'](或对于验证集分别为val_loss)可以完成这项工作。

在这里,'loss''val_loss';给予

print(history.history.keys())

查看适用于您的情况的密钥,您将在其中找到丢失所需的密钥(甚至可以相同,即'loss''val_loss')。

请注意,您应该从模型编辑中完全删除metrics=['accuracy']-正如您正确指出的那样,在回归设置中准确性没有意义(您可能要检查What function defines accuracy in Keras when the loss is mean squared error (MSE)?)。