如何在不恢复图形的情况下从张量流检查点提取权重和其他变量值?

时间:2018-11-20 10:43:44

标签: python tensorflow checkpoint

提供了一个检查点文件,但是没有生成网络的元图或代码,我想提取检查点文件中变量的存储值。

因此,在不恢复图形的情况下,如何提取存储在thr检查点中的值。我可以将所有内容从检查点转换为numpy数组或类似的字典。

1 个答案:

答案 0 :(得分:1)

找到解决方案:

reader = tf.train.NewCheckpointReader("/path/to/checkpoint")
shapes_dict = reader.get_variable_to_shape_map()  # use it to get the variable names
extracted_values = reader.get_tensor(shapes_dict.keys()[0])
# array([[ 0.       , -1.8053141],
#        [-1.5647348,  0.       ]], dtype=float32)

在API r1.12的当前文档中,tf.train.NewCheckpointReadernot really documented。 但是您可以在源代码here中看到用法示例。