了解TensorFlow检查点加载吗?

时间:2018-08-17 17:04:44

标签: python tensorflow

TF检查点中包含什么?例如,估算器存储一个单独的文件,其中包含GraphDef原型,您基本上可以执行一个tf.import_graph_def(),然后创建一个tf.train.Saver()并将检查点恢复到图形中。现在,如果您有另一个GraphDef描述了一个完全不同的图,而该图恰好共享完全相同的变量名称以及匹配的变量维,那么您是否可以将检查点加载到该图中?换句话说,它只是变量名到值的映射,还是假设在加载过程中要检查的其他图形?如果您尝试将检查点加载到原始图子集的图中(即张量尺寸和名称匹配,但缺少一些名称)怎么办?

1 个答案:

答案 0 :(得分:1)

人们什么时候开始阅读文档(?): https://www.tensorflow.org/mobile/prepare_models

这些是不同的概念。只要形状匹配,就可以仅加载权重。如果错过比赛,您将得到:

  

从检查点恢复失败。这很可能是由于   当前图和来自检查点的图之间不匹配。   请确保您没有根据以下内容更改期望的图表   检查点。

但是,您可以调整图形完全不同的非平凡情况:

import tensorflow as tf
import numpy as np

test_data = np.arange(4).reshape(1, 2, 2, 1)

# a simple graph and everything is fine
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
output = tf.layers.conv2d(input, 3, kernel_size=1, name='test', use_bias=False)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(output, {input: test_data}))
  saver = tf.train.Saver()
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print(tf.trainable_variables())

# reset previous elements
tf.reset_default_graph()

# a new graph
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
# and wait: this is complete different but same name and shape
W = tf.get_variable('test/kernel', shape=[1, 1, 1, 3])
# but the graph has different operations
output = input + W

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver()
  saver.restore(sess, "/tmp/model.ckpt")
  print(sess.run(output, {input: test_data}))

就我而言:

# 1st version (original graph)
[[[[-0.         -0.         -0.        ]
   [-0.08429337 -1.0156475  -0.42691123]]

  [[-0.16858673 -2.031295   -0.85382247]
   [-0.2528801  -3.0469427  -1.2807337 ]]]]
# 2nd version (altered graph)
[[[[-0.08429337 -1.0156475  -0.42691123]
   [ 0.91570663 -0.01564753  0.57308877]]

  [[ 1.9157066   0.98435247  1.5730888 ]
   [ 2.9157066   1.9843525   2.5730886 ]]]]