Keras-获得回调中批处理中使用的确切培训示例

时间:2018-11-26 21:29:32

标签: python keras neural-network

我在Keras中训练神经网络时遇到问题。每个时期,损失将稳定减少,达到大约1e-9,然后在时期的中间(可能在任何地方),损失跃升至5e-5,最终稳定在每个时期相同的最终损失。我相信这是由于我的数据集中的一些脏数据导致模型无法训练超过某个点,尽管我真的不确定。

要检验我的假设,我想创建一个自定义的Keras回调对象,该对象将确定一批批次后损失是否有足够大的跃迁,并指出是哪批次引起了损失。问题在于提供给batch的{​​{1}}参数只是批次 number ,实际上不是该批次中使用的训练示例。此外,传入的keras.callbacks.Callback.on_batch_end字典也只包含logsloss

这意味着我实际上无法确定哪些数据导致了损失的增加。有什么方法可以确定导致每个时期出现跳跃的确切训练示例?有什么方法可以在回调中访问它?

0 个答案:

没有答案