我正在使用Keras API(与TF2.0-beta
一起实现自定义图层)。我想在计算中使用时期号,以便随时间衰减参数(意思是-在call()
方法中)。
我习惯tf.get_global_step()
,但了解TF不推荐使用所有全局范围,这绝对是有充分理由的。
如果我有模型实例,可以使用model.optimizer.iterations
,但是我不确定在实现图层时如何获取父模型的实例。
我是否有任何办法做到这一点,或者唯一的办法就是让图层公开一个回调,该回调将更新我要衰减的参数?还有其他想法吗?理想情况下,不会让该层的用户意识到内部细节(这就是为什么我不喜欢Callback方法的原因-用户必须将它们添加到模型中)。