如何在检查点中列出某些变量?

时间:2016-08-14 16:30:47

标签: variables tensorflow state

我正在使用自动编码器。我的检查点包含网络的完整状态(即编码器,解码器,优化器等)。我想用编码来搞错。因此,我只需要在评估模式下使用网络的解码器部分。

如何从现有检查点中只读取一些特定变量,以便我可以在另一个模型中重用它们的值?

2 个答案:

答案 0 :(得分:2)

另一种方法是打印所有检查点张量(或者只有一个,如果指定)及其内容:

from tensorflow.python.tools import inspect_checkpoint as inch
inch.print_tensors_in_checkpoint_file('path/to/ckpt', '', True)
"""
Args:
  file_name: Name of the checkpoint file.
  tensor_name: Name of the tensor in the checkpoint file to print.
  all_tensors: Boolean indicating whether to print all tensors.
"""

它将始终打印张量的内容。

而且,虽然我们在这里,但这里是如何使用checkpoint_utils.py(由前面的答案建议):

from tensorflow.contrib.framework.python.framework import checkpoint_utils

var_list = checkpoint_utils.list_variables('./')
for v in var_list:
    print(v)

答案 1 :(得分:1)

您可以使用.ppt查看.ckpt文件中的已保存变量

import tensorflow as tf

variables_in_checkpoint = tf.train.list_variables('path.ckpt')

print("Variables found in checkpoint file",variables_in_checkpoint)