最近,我开始在TensorFlow中使用EMA方法(如下所示)
我创建了一个Network类来构建神经网络体系结构,然后使用该类的成员函数定义了两个模型。现在,我希望model2使用model1的EMA参数。我指的是How to use Exponential Moving Average in Tensorflow
但是model2的参数没有随model1改变。
哪一部分错了,如何修改? 非常感谢!!
with tf.variable_scope('model') as scope:
'''
Build is a member function of the class
'''
model= Network(self.config)
model.build(net_input=data_x, net_label=data_y, net_tag=data_tag)
self.net = model.net
self.trainable_list = model.trainable_list
self.variables = model.variables
ema = tf.train.ExponentialMovingAverage(0.9)
var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope.name)
update_op = ema.apply(var_class)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)
def use_ema_variables(getter, name, *_, **__):
var = getter(name, *_, **__)
ema_var = ema.average(var)
return ema_var if ema_var else var
with tf.variable_scope('ema_model',custom_getter=use_ema_variables) as scope:
# ema_model
ema_model= Network(self.config)
ema_model.build(net_input=data_x, net_label=data_y, net_tag=data_tag)
'''
ema_model don't get ema parameters
'''