在Keras回调中访问变量

时间:2016-10-16 22:11:12

标签: python keras

所以我实施了CNN。我已经确认有效的自定义回调,但我有一个问题。

这是一个示例输出。 迭代5的示例(为简单起见,批量大小为10,000)

50000/60000 [========================>.....] - ETA: 10s ('new lr:', 0.01)

('accuracy:', 0.70)

我有2个回调(测试工作如输出中所示): (1)改变每次迭代的学习率。 (2)在每次迭代时打印精度。

我有一个外部脚本,通过考虑准确性来确定学习率。

问题: 如何使每次迭代的准确性可用,以便外部脚本可以访问它?实质上是每次迭代时的可访问变量。只有在AccuracyCallback.accuracy

的过程结束后,我才能访问它

问题 我可以通过不断变化的学习率。但是,如果在每次迭代时以可访问变量的形式传递准确度,我如何获得准确性?

示例 我的外部脚本确定迭代1:0.01的学习率。如何在迭代1而不是print语句中将准确性作为外部脚本中的可访问变量?

1 个答案:

答案 0 :(得分:0)

你可以create your own callback

class AccCallback(keras.callbacks.Callback):

    def on_batch_end(self, batch, logs={}):
        accuracy = logs.get('acc')
        # pass accuracy to your 'external' script and set new lr here

为了使logs.get('acc')起作用,您必须告诉Keras监控它:

model.compile(optimizer='...', loss='...', metrics=['accuracy'])

最后,请注意此处accuracy的类型为ndarray。如果它引起任何问题,我建议包装它:float(accuracy)