Reinitializing the learning rate when training from a checkpoint in TensorFlow

Date: 2018-06-25 20:46:54

Tags: tensorflow

What is the correct way to load a TensorFlow model from a checkpoint but reset the learning_rate so that it starts from scratch rather than continuing from the checkpoint?

I am initializing my optimizer with the following code.

learning_rate = tf.train.polynomial_decay(start_learning_rate, self.global_step,
                                          decay_steps, end_learning_rate,
                                          power=power, name="new_one2")
opt = tf.train.AdamOptimizer(learning_rate)

I tried to restore from the checkpoint while returning only those variables whose names do not contain 'Adam', so that the Adam optimizer gets reinitialized from scratch. But that did not help. Here is the function I used.

def optimistic_restore(self, save_file, ignore_vars=None, verbose=False,
                       ignore_incompatible_shapes=False):
    """Tries to restore all variables in the save file.

    Variables that do not exist or have an incompatible shape are ignored.
    Raises TypeError if there is a type mismatch for compatible shapes.

    save_file: str
        Path to the checkpoint without the .index, .meta or .data extensions.
    ignore_vars: list, tuple or set of str
        These variables will be ignored.
    verbose: bool
        If True, prints which variables will be restored.
    ignore_incompatible_shapes: bool
        If True, ignores variables with incompatible shapes. If False,
        raises a runtime error if shapes are incompatible.
    """
    def vprint(*args, **kwargs):
        if verbose:
            print(*args, flush=True, **kwargs)

    if ignore_vars is None:
        ignore_vars = []

    # Read the checkpoint and keep only the variables whose names do not
    # contain 'Adam', i.e. drop the optimizer slot variables.
    reader = tf.train.NewCheckpointReader(save_file)
    var_to_shape_map = reader.get_variable_to_shape_map()
    var_list = []
    for key in sorted(var_to_shape_map):
        if 'Adam' not in key:
            var_list.append(key)
    return var_list

This gives me all the variables except the AdamOptimizer ones. I then pass the resulting var_list to the saver, as shown below.

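The saver snippet itself did not survive in this post. A minimal sketch of the step being described, assuming the graph is already built and sess and save_file are in scope:

# Sketch only: map the checkpoint names returned above back to the
# variables living in the current graph, then restore just those.
name_to_var = {v.op.name: v for v in tf.global_variables()}
restore_vars = [name_to_var[n] for n in var_list if n in name_to_var]
saver = tf.train.Saver(var_list=restore_vars)
saver.restore(sess, save_file)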

But this still does not help: my learning rate does not get reset. Can someone help me?
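(For context: the decayed learning rate is computed from global_step, which is itself a variable saved in the checkpoint, so a filter that only drops the Adam slots will still restore the step count. A quick way to confirm this, assuming the step variable was saved under the name 'global_step':)

reader = tf.train.NewCheckpointReader(save_file)
# If 'global_step' is listed here, restoring it makes the decay schedule
# resume where it left off instead of restarting.
print('global_step' in reader.get_variable_to_shape_map())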

1 Answer:

Answer 0 (score: 0)

In case anyone else is working through the same problem:

The solution is quite simple. I just made the step that drives my polynomial_decay learning rate a placeholder instead of the global_step variable, i.e.

custom_lr = tf.placeholder(tf.int32)
learning_rate = tf.train.polynomial_decay(start_learning_rate, custom_lr,
                                              decay_steps, end_learning_rate,
                                              power=power,name="new_one2")

Then simply feed the current step in through the feed dict; because the step now comes from the placeholder rather than the restored global_step, the decay restarts from zero.

for step in range(0, 1000):
    # 'train_op' is assumed here; the original snippet omitted the op to run
    sess.run(train_op, feed_dict={custom_lr: step})
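
Putting the pieces together, here is a minimal, self-contained sketch of the fix; the toy parameter, loss, and schedule constants are stand-ins for whatever the real model uses:

import tensorflow as tf

# Stand-in values for the question's unspecified hyperparameters.
start_learning_rate, end_learning_rate = 1e-3, 1e-5
decay_steps, power = 1000, 1.0

x = tf.Variable(5.0)   # toy parameter
loss = tf.square(x)    # toy loss

# The schedule is driven by a fed-in step instead of the global_step
# variable, so a restored checkpoint cannot advance the decay.
custom_lr = tf.placeholder(tf.int32)
learning_rate = tf.train.polynomial_decay(start_learning_rate, custom_lr,
                                          decay_steps, end_learning_rate,
                                          power=power, name="new_one2")
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # (restore the non-Adam variables from the checkpoint here, as in
    # the question, if resuming the weights)
    for step in range(0, 1000):
        sess.run(train_op, feed_dict={custom_lr: step})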