我希望看到TensorFlow检查点中保存的变量及其值。如何找到TensorFlow检查点中保存的变量名?
我使用了statsmodels
here。但是,TensorFlow的文档中没有给出它。还有其他办法吗?
答案 0 :(得分:40)
使用示例:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')
# List contents of v0 tensor.
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')
# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')
更新: all_tensors
参数自Tensorflow 0.12.0-rc0以来已添加到print_tensors_in_checkpoint_file
,因此您可能需要添加all_tensors=False
或all_tensors=True
如果需要。
替代方法:
from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key)) # Remove this is you want to print only variable names
希望它有所帮助。
答案 1 :(得分:16)
您可以使用inspect_checkpoint.py
工具。
因此,例如,如果您将检查点存储在当前目录中,那么您可以按如下方式打印变量及其值
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
答案 2 :(得分:10)
更多细节。
如果您的模型是使用V2格式保存的,例如,如果我们在目录/my/dir/
中有以下文件
model-10000.data-00000-of-00001
model-10000.index
model-10000.meta
然后file_name
参数应该只是前缀,即
print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)
请参阅https://github.com/tensorflow/tensorflow/issues/7696进行讨论。
答案 3 :(得分:0)
将更多参数详细信息添加到print_tensors_in_checkpoint_file
file_name
:不是物理文件,只是文件名的前缀
如果未提供tensor_name
,则打印张量名称和形状
在检查点文件中。如果提供了tensor_name
,则打印张量的内容。(inspect_checkpoint.py)
如果all_tensor_names
为True
,则打印所有张量名称
如果all_tensor
为'True',则打印所有张量名称和相应的内容。
NB all_tensor
和all_tensor_names
将覆盖tensor_name
答案 4 :(得分:0)
上述答案的更新
对于最新的Tensorflow版本(已在TF 1.13+上验证),更简洁的方法如下
ckpt_reader = tf.train.load_checkpoint(ckpt_dir_or_file)
value = ckpt_reader.get_tensor(name_of_the_tensor)
name_of_the_tensor
应该对应于变量名(您要检查的值)。要在检查点中获取变量名称和形状的列表,可以通过
vars_list = tf.train.list_variables(ckpt_dir_or_file)
答案 5 :(得分:0)
要添加旁注,print_tensors_in_checkpoint_file 无法打印大张量中的所有值(某些值将被省略为“...”)。要查看所有值,您可以使用如下代码
import tensorflow as tf
tf.enable_eager_execution()
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader('/dir/to/ckpt/model.ckpt-81230')
t = reader.get_tensor('YOUR_TENSOR_NAME_HERE')
# t is an numpy array, and you can check it like print(list(t))