DNNClassifier的EarlyStopping损失

时间:2019-01-29 14:15:23

标签: python tensorflow

我正在尝试使用tensorflow.keras.callbacks.EarlyStopping在我的DNNClassifier模型中使用钩子,但是我不知道要在monitor中放入什么。该文档在这里并没有帮助。

通过查看代码,将softmax交叉熵用作损失函数,但对于DNNRegressor,损失节点为dnn/head/weighted_loss/Sum(根据this thread)。我已经尝试过启动并运行Tensorboard,但是我无法运行,并且保存的模型中的import script在我的机器上同样存在缺陷。

有没有办法找出DNNClassifier损失的节点是什么?

1 个答案:

答案 0 :(得分:1)

monitor不是指图形节点或图层,而是指损耗或度量值。实际上,可以使用logs字典中存在的任何值:https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/callbacks.py#L676

例如,可以使用logs来检查CSVLogger中的值而无需调试:

csv_logger = CSVLogger(filename=os.path.join(args.log_dir, 'train.csv'), separator=',', append=False)

如果无法写入文件,则可以将logs中的所有内容打印到标准输出:

mycallback = LambdaCallback(on_epoch_end=lambda epoch, logs: print('\n'.join(['{}: {}'.format(k, v) for k, v in logs.items()])))

如果logs中没有该指标,则可以使用LambdaCallback将该指标放在此处。例如:

eval_callback = LambdaCallback(on_epoch_end=lambda epoch, logs: logs.update({'metric_name': get_metric_value()}))

early_stopping = EarlyStopping(monitor='metric_name', min_delta=0.0, patience=10, verbose=1, mode='min')