TF检查点中包含什么?例如,估算器存储一个单独的文件,其中包含GraphDef
原型,您基本上可以执行一个tf.import_graph_def()
,然后创建一个tf.train.Saver()
并将检查点恢复到图形中。现在,如果您有另一个GraphDef
描述了一个完全不同的图,而该图恰好共享完全相同的变量名称以及匹配的变量维,那么您是否可以将检查点加载到该图中?换句话说,它只是变量名到值的映射,还是假设在加载过程中要检查的其他图形?如果您尝试将检查点加载到原始图子集的图中(即张量尺寸和名称匹配,但缺少一些名称)怎么办?
答案 0 :(得分:1)
人们什么时候开始阅读文档(?): https://www.tensorflow.org/mobile/prepare_models
这些是不同的概念。只要形状匹配,就可以仅加载权重。如果错过比赛,您将得到:
从检查点恢复失败。这很可能是由于 当前图和来自检查点的图之间不匹配。 请确保您没有根据以下内容更改期望的图表 检查点。
但是,您可以调整图形完全不同的非平凡情况:
import tensorflow as tf
import numpy as np
test_data = np.arange(4).reshape(1, 2, 2, 1)
# a simple graph and everything is fine
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
output = tf.layers.conv2d(input, 3, kernel_size=1, name='test', use_bias=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(output, {input: test_data}))
saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")
print(tf.trainable_variables())
# reset previous elements
tf.reset_default_graph()
# a new graph
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
# and wait: this is complete different but same name and shape
W = tf.get_variable('test/kernel', shape=[1, 1, 1, 3])
# but the graph has different operations
output = input + W
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, "/tmp/model.ckpt")
print(sess.run(output, {input: test_data}))
就我而言:
# 1st version (original graph)
[[[[-0. -0. -0. ]
[-0.08429337 -1.0156475 -0.42691123]]
[[-0.16858673 -2.031295 -0.85382247]
[-0.2528801 -3.0469427 -1.2807337 ]]]]
# 2nd version (altered graph)
[[[[-0.08429337 -1.0156475 -0.42691123]
[ 0.91570663 -0.01564753 0.57308877]]
[[ 1.9157066 0.98435247 1.5730888 ]
[ 2.9157066 1.9843525 2.5730886 ]]]]