我在tensorflow.python.keras
与tensorflow.contrib.opt
中遇到了Nadam优化器的实现差异。
在https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/optimizers.py#L639
我们在Nadam
类的get_updates函数中包含以下几行:
...
# Due to the recommendations in [2], i.e. warming momentum schedule
momentum_cache_t = self.beta_1 * (
1. - 0.5 *
(math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
momentum_cache_t_1 = self.beta_1 * (
1. - 0.5 *
(math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
m_schedule_new = self.m_schedule * momentum_cache_t
m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
self.updates.append((self.m_schedule, m_schedule_new))
...
,self.schedule_decay的默认值为0.004。上面摘录中的注释引用了http://www.cs.toronto.edu/~fritz/absps/momentum.pdf。
另一方面,在tensorflow.contrib.opt中的实现 https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/opt/python/training/nadam_optimizer.py 不包括任何衰减。
我的结果明显不同于使用另一种方法,即与实施tensorflow.python.keras.optimizers.Nadam
相比,tensorflow.contrib.opt.NadamOptimizer
实现的损失要低得多
有什么简单的方法可以将tf.contrib.opt.NadamOptimzier
扩展到在keras中实现的功能?