如何在TensorFlow中使用EMA衰减?

时间:2019-03-16 08:18:01

标签: python tensorflow moving-average

最近,我开始在TensorFlow中使用EMA方法(如下所示)

我创建了一个Network类来构建神经网络体系结构,然后使用该类的成员函数定义了两个模型。现在,我希望model2使用model1的EMA参数。我指的是How to use Exponential Moving Average in Tensorflow

但是model2的参数没有随model1改变。

哪一部分错了,如何修改? 非常感谢!!

    with tf.variable_scope('model') as scope:
        '''
        Build is a member function of the class
        '''
        model= Network(self.config)
        model.build(net_input=data_x, net_label=data_y, net_tag=data_tag)
        self.net = model.net
        self.trainable_list = model.trainable_list
        self.variables = model.variables

    ema = tf.train.ExponentialMovingAverage(0.9)
    var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope.name)
    update_op = ema.apply(var_class)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)

    def use_ema_variables(getter, name, *_, **__):
        var = getter(name, *_, **__)
        ema_var = ema.average(var)
        return ema_var if ema_var else var

    with tf.variable_scope('ema_model',custom_getter=use_ema_variables) as scope:
        # ema_model
        ema_model= Network(self.config)
        ema_model.build(net_input=data_x, net_label=data_y, net_tag=data_tag)
        '''
        ema_model don't get ema parameters
        '''

0 个答案:

没有答案