TensorFlow中已恢复变量的列表

时间:2016-11-21 11:42:06

标签: tensorflow

在从ckpt文件恢复变量值之前,我需要在恢复变量的文件中创建这些变量。但是在这个新文件中,可能还有一些其他变量不在ckpt文件中。是否可以只打印一个已恢复的变量列表(在这种情况下,tf.all_variables不起作用)?

3 个答案:

答案 0 :(得分:3)

您可以使用inspect_checkpoint.py工具列出要恢复的检查点中的变量。

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# List ALL tensors.
print_tensors_in_checkpoint_file(file_name='./model.ckpt', tensor_name='')

# List contents of a specific tensor.
print_tensors_in_checkpoint_file(file_name='./model.ckpt', tensor_name='conv1_w')

另一种方法:

from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader('./model.ckpt')
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names

列出当前图表中的所有全局变量:

for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(v)

希望它有所帮助。

答案 1 :(得分:1)

如果你想要一个可以保存的变量列表,我相信你可以使用它:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)

答案 2 :(得分:1)

使用 print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors)

参数数量:     file_name:检查点文件的名称。     tensor_name:要打印的检查点文件中张量的名称。     all_tensors:指示是否打印所有张量的布尔值。