如何使Checkpoint将力矩和其他相关变量存储在tf.train Optimizers中

时间:2019-06-24 15:37:20

标签: python tensorflow

当我的代码由于某种原因在计算机上停止时遇到问题,因此我不得不重新启动代码并通过加载最新的检查点文件来继续训练过程。

我发现我加载的检查点前后的性能不一致,并且性能下降很多。

因此,由于我的代码使用tf.train.AdamOptimizer,所以我猜想检查点没有在前面的步骤中存储矩量向量和梯度,并且当我加载检查点时,矩量向量被初始化为零。 / p>

我正确吗?

是否有任何方法可以帮助在检查点中存储与Adamopotimizer相关的向量,以便如果我的机器再次停机,从最新检查点重新启动不会有任何影响?

谢谢!

1 个答案:

答案 0 :(得分:0)

出于好奇,我检查了它是否为真,并且一切似乎都工作正常:所有变量都显示在检查点中并已正确还原。亲自看看:

import tensorflow as tf
import sys
import numpy as np
from tensorflow.python.tools import inspect_checkpoint as inch


ckpt_path = "./tmp/model.ckpt"
shape = (2, 2)

def _print_all():
  for v in tf.all_variables():
    print('%20s' % v.name, v.eval())

def _model():
    a = tf.placeholder(tf.float32, shape)
    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
      x = tf.get_variable('x', shape)

    loss = tf.matmul(a, tf.layers.batch_normalization(x))
    step = tf.train.AdamOptimizer(0.00001).minimize(loss)
    return a, step

def train():
    a, step = _model()
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for i in range(10):
        _ = sess.run(step, feed_dict= {a:np.random.rand(*shape)})

      _print_all()
      print(saver.save(sess, ckpt_path))
      _print_all()


def check():
    a, step = _model()
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      _print_all()
      saver.restore(sess, ckpt_path)
      _print_all()


def checkpoint_list_vars(chpnt):
  """
  Given path to a checkpoint list all variables available in the checkpoint
  """
  from tensorflow.contrib.framework.python.framework import checkpoint_utils
  var_list = checkpoint_utils.list_variables(chpnt)
#   for v in var_list: print(v, var_val(v[0]))
#   for v in var_list: print(v)
  var_val('')

  return var_list

def var_val(name):
    inch.print_tensors_in_checkpoint_file(ckpt_path, name, True)

if 'restore' in sys.argv:
    check()
elif 'checkpnt' in sys.argv:
    checkpoint_list_vars(ckpt_path)
else:
    train()

将其存储为test.py并运行

>> python test.py
>> python test.py checkpnt
>> python test.py restore