我是python和深度学习的新手,我正在尝试使用在github上找到的源代码学习从Adam到SGD优化器的两阶段切换的行为。优化器设计为在满足触发条件时切换到SGD。触发条件如下:
cond_update = gen_math_ops.logical_or(gen_math_ops.logical_and(gen_math_ops.logical_and( self.iterations > 1, lg_err < 1e-2 ), lam_t > 0 ), cond )[0]
通过反复试验,我能够找到一个允许进行切换的1e-2值。我想提取lg_err的值来绘制并研究其在整个迭代过程中的值。我在原始代码中添加了几行,并进行了10个时期的训练,以尝试提取lg_err的值,但是我得到了以下csv输出文件。csv output
这是自定义优化器的修改后的源代码,我添加到第8、32、71-74行:
from tensorflow.python.framework import ops
from tensorflow.python.keras import optimizers
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
import csv
class SWATS(optimizers.Optimizer):
def __init__(self,lr=0.001,lr_boost=10.0,beta_1=0.9,beta_2=0.999,epsilon=None,decay=0.,amsgrad=False,**kwargs):
super(SWATS, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.lr = K.variable(lr, name='lr')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
self.amsgrad = amsgrad
def get_updates(self, loss, params):
def m_switch(pred, tensor_a, tensor_b):
def f_true(): return tensor_a
def f_false(): return tensor_b
return control_flow_ops.cond(pred, f_true, f_false, strict=True)
grads = self.get_gradients(loss, params)
self.updates = []
Trial = [] #create trial
lr = self.lr
if self.initial_decay > 0:
lr = lr * ( 1. / (1. + self.decay * math_ops.cast(self.iterations,K.dtype(self.decay))) )
with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
t = math_ops.cast(self.iterations, K.floatx())
lr_bc = gen_math_ops.sqrt(1. - math_ops.pow(self.beta_2, t)) / (1. - math_ops.pow(self.beta_1, t))
ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
lams = [K.zeros(1, dtype=K.dtype(p)) for p in params]
conds = [K.variable(False, dtype='bool') for p in params]
if self.amsgrad:
vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
else:
vhats = [K.zeros(1) for _ in params]
self.weights = [self.iterations] + ms + vs + vhats + lams + conds
for p, g, m, v, vhat, lam, cond in zip(params, grads, ms, vs, vhats, lams, conds):
beta_g = m_switch(cond, 1.0, 1.0 - self.beta_1)
m_t = (self.beta_1 * m) + beta_g * g
v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
if self.amsgrad:
vhat_t = math_ops.maximum(vhat, v_t)
p_t_ada = lr_bc * m_t / (gen_math_ops.sqrt(vhat_t) + self.epsilon)
self.updates.append(state_ops.assign(vhat, vhat_t))
else:
p_t_ada = lr_bc * m_t / (gen_math_ops.sqrt(v_t) + self.epsilon)
gamma_den = math_ops.reduce_sum(p_t_ada * g)
gamma = math_ops.reduce_sum(gen_math_ops.square(p_t_ada)) / (math_ops.abs(gamma_den) + self.epsilon) * (gen_math_ops.sign(gamma_den) + self.epsilon)
lam_t = (self.beta_2 * lam) + (1. - self.beta_2) * gamma
lam_prime = lam / (1. - math_ops.pow(self.beta_2, t))
lam_t_prime = lam_t / (1. - math_ops.pow(self.beta_2, t))
lg_err = math_ops.abs( lam_t_prime - gamma )
# extract lg_err values into array
Trial.append(lg_err)
with open('lg_err_values', 'w', newline='') as myfile:
wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
wr.writerow(Trial)
cond_update = gen_math_ops.logical_or(gen_math_ops.logical_and(gen_math_ops.logical_and( self.iterations > 1, lg_err < 1e-4 ), lam_t > 0 ), cond )[0]
lam_update = m_switch(cond_update, lam, lam_t)
self.updates.append(state_ops.assign(lam, lam_update))
self.updates.append(state_ops.assign(cond, cond_update))
p_t_sgd = (1. - self.beta_1) * lam_prime * m_t
self.updates.append(state_ops.assign(m, m_t))
self.updates.append(state_ops.assign(v, v_t))
new_p = m_switch(cond, p - lr * p_t_sgd, p - lr * p_t_ada)
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
self.updates.append(state_ops.assign(p, new_p))
return self.updates
def get_config(self):
config = {
'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'decay': float(K.get_value(self.decay)),
'epsilon': self.epsilon,
'amsgrad': self.amsgrad
}
base_config = super(SWATS, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
如果我冒犯了这个社区的任何礼节,我事先表示歉意。