TensorFlow v1.10+: get_checkpoint_state with the Estimator API: what value does latest_filename expect in order to select a specific checkpoint?

Asked: 2019-06-27 08:19:36

Tags: python tensorflow tensorflow-estimator

The function tf.train.get_checkpoint_state (line 246) has the following signature (it is defined in the file checkpoint_management):

def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.
  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.
  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.
  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """

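It is not clear to me what exactly latest_filename is, or how I can get a specific checkpoint rather than the latest one (when several checkpoints are in the same directory).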

The second line of the function above is:

coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)

which is defined on line 43 as:

def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.
  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.
  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)

So this narrows the type down to a string.
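To make the path resolution concrete, here is a minimal standalone sketch that mirrors the helper quoted above. It is re-implemented under the hypothetical name `checkpoint_state_path` so that it runs without TensorFlow, and the directory and step number are made up. It suggests that latest_filename names the state file inside the directory rather than an individual checkpoint.

```python
import os


def checkpoint_state_path(save_dir, latest_filename=None):
    """Mirror of the quoted _GetCheckpointFilename helper (hypothetical name)."""
    if latest_filename is None:
        latest_filename = "checkpoint"
    return os.path.join(save_dir, latest_filename)


# Default: the state file named "checkpoint" inside the directory.
default_path = checkpoint_state_path("path/to/checkpoints")

# Passing a checkpoint prefix only changes which file is opened as the
# state file; it does not select that checkpoint.
prefix_path = checkpoint_state_path("path/to/checkpoints", "model.ckpt-1000")

print(default_path)
print(prefix_path)
```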

The tf.estimator.Estimator API saves checkpoints in the following form:

model.ckpt-<#####>.index
model.ckpt-<#####>.meta
model.ckpt-<#####>.data-<#####>-of-<#####>

So, if I am using these checkpoints, then to specify which checkpoint I want, I would expect to call:

tf.train.get_checkpoint_state(estimator_model_dir, latest_filename="model.ckpt-<#####>")
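The three files listed above share a common prefix, model.ckpt-<#####>, and as far as I understand TensorFlow refers to a checkpoint by that prefix rather than by any single file. The sketch below shows one way to enumerate the prefixes in a directory. `list_checkpoint_prefixes` is a hypothetical helper, not part of TensorFlow, and the demo uses empty placeholder files instead of real checkpoints.

```python
import os
import re
import tempfile

# Matches Estimator-style index files such as "model.ckpt-1000.index".
_INDEX_RE = re.compile(r"^(model\.ckpt-(\d+))\.index$")


def list_checkpoint_prefixes(model_dir):
    """Return (step, prefix_path) pairs found in model_dir, sorted by step."""
    found = []
    for name in os.listdir(model_dir):
        match = _INDEX_RE.match(name)
        if match:
            prefix = os.path.join(model_dir, match.group(1))
            found.append((int(match.group(2)), prefix))
    return sorted(found)


# Demo with empty placeholder files standing in for real checkpoints.
demo_dir = tempfile.mkdtemp()
for step in (200, 1000):
    for suffix in (".index", ".meta", ".data-00000-of-00001"):
        path = os.path.join(demo_dir, "model.ckpt-%d%s" % (step, suffix))
        open(path, "w").close()

prefixes = list_checkpoint_prefixes(demo_dir)
print(prefixes)
```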

I have tried:

import tensorflow as tf

CHECKPOINT_DIR = "path/to/checkpoints"
ckpt_num = "model.ckpt-#####"

file = ckpt_num
# file = ckpt_num + 'data-00000-of-00001'
# file = ckpt_num + 'index'
# file = ckpt_num + 'meta'

checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR, file)
checkpoint.model_checkpoint_path

All of these raise an error such as:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa1 in position 0: invalid start byte
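One plausible reading of this error, offered as an assumption rather than a confirmed diagnosis: the file named by latest_filename is read as UTF-8 text, so pointing it at a binary file (such as the .index or .data file) fails during decoding. The sketch below reproduces the same kind of failure in plain Python; the bytes are made up and only stand in for binary file content.

```python
# 0xa1 is a UTF-8 continuation byte, so it cannot start a character.
binary_content = b"\xa1\x00\x01"  # made-up bytes standing in for a binary file

try:
    binary_content.decode("utf-8")
    error_message = None
except UnicodeDecodeError as exc:
    # Keep the message so it can be compared with the traceback above.
    error_message = str(exc)

print(error_message)
```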

Omitting file gives me the latest checkpoint, which may not be the one I want...
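A possible way around this, sketched under the assumption that the default call is kept and the wanted checkpoint is picked out of the returned state afterwards. `pick_checkpoint` is a hypothetical helper, not part of TensorFlow, and the paths are made up.

```python
import re


def pick_checkpoint(all_paths, step):
    """Return the path in all_paths whose name ends with "-<step>", else None."""
    pattern = re.compile(r"-%d$" % step)
    for path in all_paths:
        if pattern.search(path):
            return path
    return None


# Made-up paths, shaped like CheckpointState.all_model_checkpoint_paths.
paths = [
    "path/to/checkpoints/model.ckpt-200",
    "path/to/checkpoints/model.ckpt-1000",
]

chosen = pick_checkpoint(paths, 200)
missing = pick_checkpoint(paths, 300)
print(chosen, missing)
```

With TensorFlow available, `paths` could presumably come from `tf.train.get_checkpoint_state(CHECKPOINT_DIR).all_model_checkpoint_paths`, which lists only the checkpoints still recorded in the state file. The chosen prefix could then be passed on, for example as the `checkpoint_path` argument of `Estimator.predict` or `Estimator.evaluate`.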

0 answers:

No answers yet