我正在尝试创建一个优化器,该优化器可以基于损失来调整其学习率。开始使用这种方法后,我意识到了优化器的工作原理,应该改为学习速率调度器。
无论如何,我有兴趣找出为什么我当前的方法不起作用,因为我认为这最终将帮助我解决符号张量和数组之间的差异。
尤其是,代码运行并且损失正在减少,但是
import keras.backend as K
import numpy as np
from keras import callbacks, optimizers
from keras.models import Sequential
from keras.layers import Dense
from keras.legacy import interfaces
class AutoOptim(optimizers.Nadam):
def __init__(self,**kwargs):
super().__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.counter = K.variable(0, name='counter',dtype='int32')
self.lr_cand = K.variable(self.lr, name='lr_cand')
self.lastloss= K.variable(1e9, name='lastloss')
self.dloss = K.variable([1,0,0], name='dloss')
self.lr_update_facs = K.constant([1.0, 1.3, 1.0/1.3])
@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):
dloss_update = K.update( self.dloss[ (self.counter+2) % 3 ] , self.dloss[ (self.counter+2) % 3 ] + (self.lastloss - loss) )
lastloss_save = K.update( self.lastloss, loss )
update_lr = K.update(self.lr_cand, K.switch( self.counter % 18, self.lr_cand, self.lr_cand * K.gather(self.lr_update_facs, K.argmax(self.dloss) ) ) )
reset_hist = K.update(self.dloss, K.switch( self.counter % 18, self.dloss, K.constant( [0.0,0.0,0.0]) ) )
lr_upd = K.update(self.lr, self.lr_cand * K.gather( self.lr_update_facs, self.counter % 3 ) )
super_updates = super().get_updates(loss,params)
counter_update = K.update(self.counter,self.counter+1 )
updates = [dloss_update, lastloss_save, update_lr, reset_hist, lr_upd, super_updates, counter_update]
return updates
model = Sequential()
model.add(Dense(1, input_dim=2, activation='relu'))
opt = AutoOptim()
model.compile(loss='mae', optimizer=opt, metrics=['accuracy'])
class My_Callback(callbacks.Callback):
def on_batch_end(self, batch, logs={}):
print(K.eval(self.model.optimizer.counter)-1, K.eval(self.model.optimizer.lr), K.eval(self.model.optimizer.lastloss), K.eval(self.model.optimizer.dloss))
#%%
X=np.random.rand(500,2)
Y=(X[:,0]+X[:,1])/2
model.fit(X,Y,epochs=1, callbacks=[My_Callback()], batch_size=10, verbose=0)
我希望看到学习率在3个值(当前,稍高,稍低)之间循环,并每18个周期设置一个新的“当前”值。
我的行为变得很不稳定,遗失物和lr没有按预期更新。
答案 0 :(得分:1)
代码格式问题:
由于存在缩进问题,导致get_updates()成为__init __()的一部分,因此未调用get_updates方法。因此,Nadam的get_updates()被调用。
修复缩进问题后,将调用AutoOptim的get_updates()。
您可以在get_updates()方法中打印这样的变量值:
print(f'Learning rate: {K.get_session().run([self.lr,self.lr_cand])}')