我正在使用自动编码器。我的检查点包含网络的完整状态(即编码器,解码器,优化器等)。我想用编码来搞错。因此,我只需要在评估模式下使用网络的解码器部分。
如何从现有检查点中只读取一些特定变量,以便我可以在另一个模型中重用它们的值?
答案 0 :(得分:2)
另一种方法是打印所有检查点张量(或者只有一个,如果指定)及其内容:
from tensorflow.python.tools import inspect_checkpoint as inch
inch.print_tensors_in_checkpoint_file('path/to/ckpt', '', True)
"""
Args:
file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print.
all_tensors: Boolean indicating whether to print all tensors.
"""
它将始终打印张量的内容。
而且,虽然我们在这里,但这里是如何使用checkpoint_utils.py
(由前面的答案建议):
from tensorflow.contrib.framework.python.framework import checkpoint_utils
var_list = checkpoint_utils.list_variables('./')
for v in var_list:
print(v)
答案 1 :(得分:1)
您可以使用.ppt查看.ckpt文件中的已保存变量
import tensorflow as tf
variables_in_checkpoint = tf.train.list_variables('path.ckpt')
print("Variables found in checkpoint file",variables_in_checkpoint)