我正在尝试使用tensorflow.keras.callbacks.EarlyStopping
在我的DNNClassifier
模型中使用钩子,但是我不知道要在monitor
中放入什么。该文档在这里并没有帮助。
通过查看代码,将softmax交叉熵用作损失函数,但对于DNNRegressor
,损失节点为dnn/head/weighted_loss/Sum
(根据this thread)。我已经尝试过启动并运行Tensorboard,但是我无法运行,并且保存的模型中的import script在我的机器上同样存在缺陷。
有没有办法找出DNNClassifier
损失的节点是什么?
答案 0 :(得分:1)
monitor
不是指图形节点或图层,而是指损耗或度量值。实际上,可以使用logs
字典中存在的任何值:https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/callbacks.py#L676
例如,可以使用logs
来检查CSVLogger
中的值而无需调试:
csv_logger = CSVLogger(filename=os.path.join(args.log_dir, 'train.csv'), separator=',', append=False)
如果无法写入文件,则可以将logs
中的所有内容打印到标准输出:
mycallback = LambdaCallback(on_epoch_end=lambda epoch, logs: print('\n'.join(['{}: {}'.format(k, v) for k, v in logs.items()])))
如果logs
中没有该指标,则可以使用LambdaCallback将该指标放在此处。例如:
eval_callback = LambdaCallback(on_epoch_end=lambda epoch, logs: logs.update({'metric_name': get_metric_value()}))
early_stopping = EarlyStopping(monitor='metric_name', min_delta=0.0, patience=10, verbose=1, mode='min')