我正在使用Model
训练TensorFlow Keras Model.fit()
。在每次使用TensorFlow的on_train_batch_end()
语法进行批处理之后,我还使用回调来记录我的训练准确性指标。另外,我正在使用另一个回调来每1000批次运行Model.evaluate()
,以计算验证集的准确性并更新logs
期间围绕回调传递的Model.fit()
字典。
查看记录的度量标准与批号之间的关系,结果非常令人困惑。 Model.evaluate()
运行后,训练准确性经历了明显的“颠簸”,最初触发了所记录的训练准确性的迅速提高,随后触发了显着的跌落训练准确性,随后恢复速度较慢(参见附图)。>
我的猜测是,这与Model.evaluate()对reset_metrics()
的调用有关,后者对每个度量标准进行循环并调用reset_states()方法。我无法弄清楚reset_states()
在做什么,如果这与我观察到的行为有关。它似乎与CategoricalAccuracy
的{{3}}父类有关。在TensorFlow文档中,我还找不到任何有用的东西。
在Model.fit()
期间显示的指标实际上是某种形式的移动平均值,而不是按批处理指标吗?在这种情况下,reset_states()
方法将重置移动平均值,可能会产生颠簸的行为。
任何对TensorFlow的内部运作有更好了解的人都能提供帮助吗?