I am trying to split my code into separate modules: one that trains the model, and another that analyzes the weights in the trained model.

When I save the model with

```python
save_path = saver.save(sess, "checkpoints5/text8.ckpt")
```

it produces four files: `['checkpoint', 'text8.ckpt.data-00000-of-00001', 'text8.ckpt.meta', 'text8.ckpt.index']`.
In a separate module I tried to restore it with this code:

```python
train_graph = tf.Graph()
with train_graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=train_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('MODEL4'))
    embed_mat = sess.run(embedding)
```
but I get this error message:
```
ValueError                                Traceback (most recent call last)
<ipython-input-15-deaad9b67888> in <module>()
      1 train_graph = tf.Graph()
      2 with train_graph.as_default():
----> 3     saver = tf.train.Saver()
      4
      5

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
   1309           time.time() + self._keep_checkpoint_every_n_hours * 3600)
   1310     elif not defer_build:
-> 1311       self.build()
   1312     if self.saver_def:
   1313       self._check_saver_def()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in build(self)
   1318     if context.executing_eagerly():
   1319       raise RuntimeError("Use save/restore instead of build in eager mode.")
-> 1320     self._build(self._filename, build_save=True, build_restore=True)
   1321
   1322   def _build_eager(self, checkpoint_path, build_save, build_restore):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in _build(self, checkpoint_path, build_save, build_restore)
   1343           return
   1344         else:
-> 1345           raise ValueError("No variables to save")
   1346       self._is_empty = False
   1347

ValueError: No variables to save
```
After reading up on this problem, it seems I need to redefine all of the variables that were used when training the model.

Is there a way to access the weights without having to redefine everything? The weights are just numbers; surely there must be a way to access them directly?
Answer 0 (score: 1)
If you just want to access the variables in a checkpoint, take a look at the `checkpoint_utils` library. It provides three useful API functions: `load_checkpoint`, `list_variables` and `load_variable`. I'm not sure whether there is a better way, but you can certainly use these functions to extract a dict of all the variables in a checkpoint, like this:

```python
import tensorflow as tf

ckpt = 'checkpoints5/text8.ckpt'
var_dict = {name: tf.train.load_checkpoint(ckpt).get_tensor(name)
            for name, _ in tf.train.list_variables(ckpt)}
print(var_dict)
```
To load a pre-trained model without having to redefine all the variables, you need more than a checkpoint. A checkpoint contains only the variables; it does not record how to restore them, i.e. how to map them back onto a graph, and it carries no actual graph (nor the associated ops). For this use case, `SavedModel` is a better fit: it stores both the model's `MetaGraph` and all of the variables, and when you restore a saved model you don't have to redefine everything by hand. The code below uses only `simple_save`.
To save the trained model:
```python
import tensorflow as tf

x = tf.placeholder(tf.float32)
y_ = tf.reshape(x, [-1, 1])
y_ = tf.layers.dense(y_, units=1)
loss = tf.losses.mean_squared_error(labels=x, predictions=y_)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for _ in range(10):
        sess.run(train_op, feed_dict={x: range(10)})
    # Let's check the bias here so that we can make sure
    # the model we restore later on is indeed our trained model here.
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))
    tf.saved_model.simple_save(sess, 'test', inputs={"x": x}, outputs={"y": y_})
```
To restore the saved model:
```python
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # A model saved by simple_save will be treated as a graph for
    # inference / serving, i.e. it uses the tag tag_constants.SERVING.
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 'test')
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))
```