如何查找检查点中保存的变量名称和值?

时间:2016-07-06 07:12:21

标签: tensorflow

我希望看到TensorFlow检查点中保存的变量及其值。如何找到TensorFlow检查点中保存的变量名?

我使用了statsmodels here。但是,TensorFlow的文档中没有给出它。还有其他办法吗?

6 个答案:

答案 0 :(得分:40)

使用示例:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

# List contents of v0 tensor.
# Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

更新: all_tensors参数自Tensorflow 0.12.0-rc0以来已添加到print_tensors_in_checkpoint_file,因此您可能需要添加all_tensors=Falseall_tensors=True如果需要。

替代方法:

from tensorflow.python import pywrap_tensorflow
import os

checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names

希望它有所帮助。

答案 1 :(得分:16)

您可以使用inspect_checkpoint.py工具。

因此,例如,如果您将检查点存储在当前目录中,那么您可以按如下方式打印变量及其值

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')

答案 2 :(得分:10)

更多细节。

如果您的模型是使用V2格式保存的,例如,如果我们在目录/my/dir/中有以下文件

model-10000.data-00000-of-00001
model-10000.index
model-10000.meta

然后file_name参数应该只是前缀,即

print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)

请参阅https://github.com/tensorflow/tensorflow/issues/7696进行讨论。

答案 3 :(得分:0)

将更多参数详细信息添加到print_tensors_in_checkpoint_file

file_name:不是物理文件,只是文件名的前缀

如果未提供tensor_name,则打印张量名称和形状 在检查点文件中。如果提供了tensor_name,则打印张量的内容。(inspect_checkpoint.py

如果all_tensor_namesTrue,则打印所有张量名称

如果all_tensor为'True',则打印所有张量名称和相应的内容。

NB all_tensorall_tensor_names将覆盖tensor_name

答案 4 :(得分:0)

上述答案的更新

对于最新的Tensorflow版本(已在TF 1.13+上验证),更简洁的方法如下

ckpt_reader = tf.train.load_checkpoint(ckpt_dir_or_file)
value = ckpt_reader.get_tensor(name_of_the_tensor)

name_of_the_tensor应该对应于变量名(您要检查的值)。要在检查点中获取变量名称和形状的列表,可以通过

进行检查
vars_list = tf.train.list_variables(ckpt_dir_or_file)

答案 5 :(得分:0)

要添加旁注,print_tensors_in_checkpoint_file 无法打印大张量中的所有值(某些值将被省略为“...”)。要查看所有值,您可以使用如下代码

import tensorflow as tf
tf.enable_eager_execution()
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader('/dir/to/ckpt/model.ckpt-81230')
t = reader.get_tensor('YOUR_TENSOR_NAME_HERE')
# t is an numpy array, and you can check it like print(list(t))