The function tf.train.get_checkpoint_state (line 246) has the following signature (defined in the file checkpoint_management):
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file. Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
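It is not clear to me what latest_filename actually is, or how I can get a specific checkpoint rather than the latest one (when several checkpoints are in the same directory).
The second line of the function above is: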
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
which is defined on line 43 as:
def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)
So this narrows the type of latest_filename down to a string.
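To make the path handling concrete, here is a minimal standalone sketch that copies the logic of the helper quoted above. The function name get_checkpoint_filename, the directory model_dir, and the step number 1000 are placeholders for illustration only. It suggests that whatever is passed as latest_filename is simply joined onto the directory and treated as the file that holds the CheckpointState:

```python
import os

def get_checkpoint_filename(save_dir, latest_filename=None):
    # Same logic as the quoted helper: fall back to "checkpoint" when no name is given.
    if latest_filename is None:
        latest_filename = "checkpoint"
    return os.path.join(save_dir, latest_filename)

# With no name, the state is looked up in <dir>/checkpoint.
default_path = get_checkpoint_filename("model_dir")

# With a name, that exact file is used as the state file, whatever it contains.
custom_path = get_checkpoint_filename("model_dir", "model.ckpt-1000.index")

print(default_path)
print(custom_path)
```

If this reading is right, latest_filename names the file that stores the state, not the checkpoint to load.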
The tf.estimator.Estimator API saves checkpoints in the following form:
model.ckpt-<#####>.index
model.ckpt-<#####>.meta
model.ckpt-<#####>.data-<#####>-of-<#####>
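As an aside, the checkpoints present in a directory can be enumerated from these file names alone, without TensorFlow. The sketch below is only an illustration under the assumption that every checkpoint has a model.ckpt-<step>.index file; list_checkpoint_prefixes is a hypothetical helper, and the demo files are empty placeholders:

```python
import os
import re
import tempfile

def list_checkpoint_prefixes(model_dir):
    """Return (step, prefix) pairs for every model.ckpt-<step>.index file, sorted by step."""
    pattern = re.compile(r"^(model\.ckpt-(\d+))\.index$")
    found = []
    for name in os.listdir(model_dir):
        match = pattern.match(name)
        if match:
            prefix = os.path.join(model_dir, match.group(1))
            found.append((int(match.group(2)), prefix))
    return sorted(found)

# Demo on a throwaway directory containing empty placeholder files.
with tempfile.TemporaryDirectory() as demo_dir:
    for fname in ("model.ckpt-100.index", "model.ckpt-100.meta",
                  "model.ckpt-2500.index", "model.ckpt-2500.data-00000-of-00001"):
        open(os.path.join(demo_dir, fname), "w").close()
    demo_steps = [step for step, _ in list_checkpoint_prefixes(demo_dir)]
    print(demo_steps)
```

The prefix (the path without .index, .meta, or .data-...) is what identifies one checkpoint across its three kinds of files.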
So if I am working with these checkpoints, to specify which checkpoint I want, I would call
tf.train.get_checkpoint_state(estimator_model_dir, latest_filename="model.ckpt-<#####>")
I tried:
CHECKPOINT_DIR = "path/to/checkpoints"
ckpt_num = "model.ckpt-#####"
file = ckpt_num
# file = ckpt_num + 'data-00000-of-00001'
# file = ckpt_num + 'index'
# file = ckpt_num + 'meta'
checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR, file)
checkpoint.model_checkpoint_path
All of these raise an error such as
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa1 in position 0: invalid start byte
Leaving the file argument out gives me the latest checkpoint, which may not be the one I want...
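One possible direction, sketched below without TensorFlow, is to read the default "checkpoint" state file directly and pick an entry from it. This assumes the state file is the plain-text proto with model_checkpoint_path and all_model_checkpoint_paths entries; read_all_checkpoint_paths and pick_checkpoint are hypothetical helpers, and the file contents in the demo are made up:

```python
import os
import re
import tempfile

def read_all_checkpoint_paths(checkpoint_dir, latest_filename="checkpoint"):
    """Collect the all_model_checkpoint_paths entries from the text state file."""
    state_path = os.path.join(checkpoint_dir, latest_filename)
    with open(state_path, "r", encoding="utf-8") as f:
        text = f.read()
    return re.findall(r'^all_model_checkpoint_paths:\s*"([^"]*)"',
                      text, flags=re.MULTILINE)

def pick_checkpoint(checkpoint_dir, step):
    """Return the path prefix listed for the given step, or None if it is not listed."""
    suffix = "-%d" % step
    for path in read_all_checkpoint_paths(checkpoint_dir):
        if path.endswith(suffix):
            # Relative entries are resolved against the checkpoint directory.
            if os.path.isabs(path):
                return path
            return os.path.join(checkpoint_dir, path)
    return None

# Demo with a made-up state file in a throwaway directory.
with tempfile.TemporaryDirectory() as demo_dir:
    with open(os.path.join(demo_dir, "checkpoint"), "w", encoding="utf-8") as f:
        f.write('model_checkpoint_path: "model.ckpt-2000"\n')
        f.write('all_model_checkpoint_paths: "model.ckpt-1000"\n')
        f.write('all_model_checkpoint_paths: "model.ckpt-2000"\n')
    chosen = pick_checkpoint(demo_dir, 1000)
    missing = pick_checkpoint(demo_dir, 500)
    print(chosen)
```

In TensorFlow itself, the CheckpointState returned by tf.train.get_checkpoint_state(CHECKPOINT_DIR) carries the same list in its all_model_checkpoint_paths field, so selecting from that list may be an alternative to passing latest_filename.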