Custom Keras optimizer using TensorFlow

Date: 2019-04-08 02:08:55

Tags: tensorflow keras

I want to implement the SPSA optimizer in Keras. Previously I used the TF implementation of SPSA from https://github.com/fraunhofer-iais/tensorflow_spsa. I found that we can pass a TF optimizer directly to Keras in model.compile(...). However, for the optimizer to work, it has to stick to the standard optimizer interface and define compute_gradients() and apply_gradients() correctly.
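
For example, in TF 1.x, Keras wraps a native tf.train.Optimizer in a TFOptimizer when it is passed to compile(). A minimal sketch with a built-in optimizer (the model, loss, and learning rate here are placeholders):

import tensorflow as tf

# any tf.train.Optimizer can be handed to tf.keras in TF 1.x
model.compile(optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.01),
              loss='categorical_crossentropy',
              metrics=['accuracy'])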

I tried to modify the code as follows:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.contrib import graph_editor as ge
from tensorflow.contrib.distributions import Bernoulli

class SimultaneousPerturbationOptimizer(tf.train.Optimizer):

    def __init__(self, inputs=None, graph=None, a=0.01, c=0.01, alpha=1.0, gamma=0.4, use_locking=False, name="SPSA"):
        super(SimultaneousPerturbationOptimizer, self).__init__(use_locking, name)
        self.work_graph = tf.get_default_graph() if graph is None else graph        # the graph to work on
        self.tvars = [var.name.split(':')[0] for var in tf.trainable_variables()]   # names of trainable variables
        self.inputs = inputs
        self.num_params = 0
        self.loss = None

        # optimizer parameters
        self.a = tf.constant(a, dtype=tf.float32, name="SPSA_a")
        self.c = tf.constant(c, dtype=tf.float32, name="SPSA_c")
        self.alpha = tf.constant(alpha, dtype=tf.float32, name="SPSA_alpha")
        self.gamma = tf.constant(gamma, dtype=tf.float32, name="SPSA_gamma")

        self.global_step_tensor = tf.Variable(0, name='global_step', trainable=False)
        # print( "SPSA:   a = {:4.3f}  c = {:4.3f}  alpha = {:4.3f}  gamma = {:4.3f}".format(a,c,alpha,gamma) )
        # print( "Trainable Variables:" )
        # for var in tf.trainable_variables():
        #     var_name = var.name.split(':')[0]
        #     print( "    {} => {}".format(var.name,var_name) )


    def _clone_model(self, model, perturbations, dst_scope):
        ''' Make a copy of the model sub-graph, connect it to the input ops
            of the original graph, and replace the trainable-variable reads
            with the perturbations supplied by the perturbator.
        '''
        def not_placeholder_or_trainvar_filter(op):
            if op.type == 'Placeholder':              # evaluation sub-graphs will be fed from original placeholders
                return False
            for var_name in self.tvars:
                # print( var_name, type(var_name) )
                if op.name.startswith(var_name):      # remove Some/Var/(read,assign,...) -- will be replaced with perturbations
                    return False
            return True

        ops_without_inputs = ge.filter_ops(model.ops, not_placeholder_or_trainvar_filter)
        # remove init op from clone if already present
        try:
            ops_without_inputs.remove(self.work_graph.get_operation_by_name("init"))
        except (KeyError, ValueError):    # no "init" op in the graph, or not part of the clone
            pass
        clone_sgv = ge.make_view(ops_without_inputs)
        clone_sgv = clone_sgv.remove_unused_ops(control_inputs=True)

        input_replacements = {}
        for t in clone_sgv.inputs:
            if t.name in perturbations.keys():                  # input from trainable var --> replace with perturbation
                input_replacements[t] = perturbations[t.name]
            else:                                               # otherwise take input from original graph
                input_replacements[t] = self.work_graph.get_tensor_by_name(t.name)
        return ge.copy_with_input_replacements(clone_sgv, input_replacements, dst_scope=dst_scope)


    def _mul_dims(self, shape):
        # total number of elements in a tensor of the given TensorShape
        n = 1
        for d in shape:
            n *= d.value
        return n


    # (The repo's original minimize() was here; I split it into the
    # compute_gradients() and apply_gradients() methods below.)

    def compute_gradients(self, loss=None, var_list=None):

        orig_graph_view = None

        trainable_vars = var_list if var_list is not None else tf.trainable_variables()

        if self.inputs is not None:
            seed_ops = [t.op for t in self.inputs]
            result = list(seed_ops)
            wave = set(seed_ops)
            while wave:                 # BFS over consumers, adapted from graph_editor.select
                new_wave = set()
                for op in wave:
                    for new_t in op.outputs:
                        if new_t == loss:
                            continue
                        for new_op in new_t.consumers():
                            if new_op not in result:
                                new_wave.add(new_op)
                for op in new_wave:
                    if op not in result:
                        result.append(op)
                wave = new_wave
            orig_graph_view = ge.sgv(result)
        else:
            orig_graph_view = ge.sgv(self.work_graph)

        self.global_step_tensor = tf.Variable(0, name='global_step', trainable=False)

        # Perturbations
        deltas = {}
        n_perturbations = {}
        p_perturbations = {}
        with tf.name_scope("Perturbator"):
            # c_t = c / (k+1)^gamma: the SPSA perturbation-size schedule
            self.c_t = tf.div(self.c, tf.pow(tf.add(tf.cast(self.global_step_tensor, tf.float32),
                                                    tf.constant(1, dtype=tf.float32)), self.gamma),
                              name="SPSA_ct")
            for var in trainable_vars:
                self.num_params += self._mul_dims(var.get_shape())
                var_name = var.name.split(':')[0]
                # Rademacher direction: delta = 1 - 2*Bernoulli(0.5), i.e. +/-1 per weight
                random = Bernoulli(tf.fill(var.get_shape(), 0.5), dtype=tf.float32)
                deltas[var] = tf.subtract(tf.constant(1, dtype=tf.float32),
                                          tf.scalar_mul(tf.constant(2, dtype=tf.float32), random.sample(1)[0]),
                                          name="SPSA_delta")
                c_t_delta = tf.scalar_mul(tf.reshape(self.c_t, []), deltas[var])
                # theta - c_t*delta and theta + c_t*delta, fed to the cloned evaluation graphs
                n_perturbations[var_name + '/read:0'] = tf.subtract(var, c_t_delta, name="perturb_n")
                p_perturbations[var_name + '/read:0'] = tf.add(var, c_t_delta, name="perturb_p")

        # Evaluator
        with tf.name_scope("Evaluator"):
            _, self.ninfo = self._clone_model(orig_graph_view, n_perturbations, 'N_Eval')
            _, self.pinfo = self._clone_model(orig_graph_view, p_perturbations, 'P_Eval')

        # Gradient estimator (the weight update itself happens in apply_gradients)
        grads_op = []
        with tf.control_dependencies([loss]):
            with tf.name_scope('Updater'):
                for var in trainable_vars:
                    # two-sided SPSA estimate: ghat = (L(theta+) - L(theta-)) / (2 * c_t * delta)
                    l_pos = self.pinfo.transformed(loss)
                    l_neg = self.ninfo.transformed(loss)
                    ghat = (l_pos - l_neg) / (tf.constant(2, dtype=tf.float32) * self.c_t * deltas[var])
                    grads_op.append(ghat)

        grads_and_vars = list(zip(grads_op, trainable_vars))

        return grads_and_vars

    def apply_gradients(self, grads_and_vars, global_step=None):
        self.global_step_tensor = tf.Variable(0, name='global_step', trainable=False) if global_step is None else global_step
        # a_t = a / (k+1)^alpha: the SPSA step-size schedule
        a_t = self.a / tf.pow(tf.add(tf.cast(self.global_step_tensor, tf.float32),
                                     tf.constant(1, dtype=tf.float32)), self.alpha)
        optimizer_ops = []
        for g, v in grads_and_vars:
            optimizer_ops.append(tf.assign_sub(v, a_t * g))
        grp = control_flow_ops.group(*optimizer_ops)
        # chain the step increment onto the returned train op so it actually runs
        with tf.control_dependencies([grp]):
            incr = tf.assign_add(self.global_step_tensor,
                                 tf.constant(1, dtype=self.global_step_tensor.dtype))
        return tf.group(incr)


    def get_clones(self, op):
        # return op's counterparts in the negative and positive evaluation clones
        return (self.ninfo.transformed(op), self.pinfo.transformed(op))



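For completeness, this is roughly how I wire the optimizer into training (a sketch; model2, x_train, y_train, x_test, y_test, batch_size, and epochs are defined elsewhere in my notebook, and the loss is a placeholder):

opt = SimultaneousPerturbationOptimizer()
model2.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
train_log = model2.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                       validation_data=(x_test, y_test))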

But when I train the model, I get this error:

---------------------------------------------------------------------------

FailedPreconditionError                   Traceback (most recent call last)

<ipython-input-16-fde0f2b12519> in <module>()
----> 1 train_log = model2.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test,y_test))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
    878           initial_epoch=initial_epoch,
    879           steps_per_epoch=steps_per_epoch,
--> 880           validation_steps=validation_steps)
    881 
    882   def evaluate(self,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, mode, validation_in_fit, **kwargs)
    327 
    328         # Get outputs.
--> 329         batch_outs = f(ins_batch)
    330         if not isinstance(batch_outs, list):
    331           batch_outs = [batch_outs]

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py in __call__(self, inputs)
   3074 
   3075     fetched = self._callable_fn(*array_vals,
-> 3076                                 run_metadata=self.run_metadata)
   3077     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3078     return nest.pack_sequence_as(self._outputs_structure,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1437           ret = tf_session.TF_SessionRunCallable(
   1438               self._session._session, self._handle, args, status,
-> 1439               run_metadata_ptr)
   1440         if run_metadata:
   1441           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    526             None, None,
    527             compat.as_text(c_api.TF_Message(self.status.status)),
--> 528             c_api.TF_GetCode(self.status.status))
    529     # Delete the underlying status object from memory otherwise it stays alive
    530     # as there is a reference to status from this from the traceback due to

FailedPreconditionError: Attempting to use uninitialized value training/TFOptimizer/global_step
     [[{{node training/TFOptimizer/global_step/read}}]]

I have very little experience with low-level TensorFlow and cannot figure out what the problem is.

How can I fix it?

0 Answers:

No answers yet.