Tensorflow:有没有办法加载预训练模型而不必重新定义所有变量?

时间:2018-05-06 00:55:51

标签: tensorflow

我试图将我的代码分成不同的模块,一个是模型训练的模块,另一个是分析模型中的权重。

使用

保存模型时
save_path = saver.save(sess, "checkpoints5/text8.ckpt")

它制作4个文件,[' checkpoint',' text8.ckpt.data-00000-of-00001',' text8.ckpt.meta', ' text8.ckpt.index']

我尝试使用此代码

在单独的模块中恢复此功能
train_graph = tf.Graph()
with train_graph.as_default():
    saver = tf.train.Saver()


with tf.Session(graph=train_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('MODEL4'))
    embed_mat = sess.run(embedding)

但我收到此错误消息

ValueError                                Traceback (most recent call last)
<ipython-input-15-deaad9b67888> in <module>()
      1 train_graph = tf.Graph()
      2 with train_graph.as_default():
----> 3     saver = tf.train.Saver()
      4 
      5 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
   1309           time.time() + self._keep_checkpoint_every_n_hours * 3600)
   1310     elif not defer_build:
-> 1311       self.build()
   1312     if self.saver_def:
   1313       self._check_saver_def()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in build(self)
   1318     if context.executing_eagerly():
   1319       raise RuntimeError("Use save/restore instead of build in eager mode.")
-> 1320     self._build(self._filename, build_save=True, build_restore=True)
   1321 
   1322   def _build_eager(self, checkpoint_path, build_save, build_restore):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in _build(self, checkpoint_path, build_save, build_restore)
   1343           return
   1344         else:
-> 1345           raise ValueError("No variables to save")
   1346       self._is_empty = False
   1347 

ValueError: No variables to save

在阅读了这个问题之后,似乎我需要重新定义训练模型时使用的所有变量。

有没有办法在不必重新定义所有内容的情况下访问权重?权重只是数字,当然必须有办法直接访问它们吗?

1 个答案:

答案 0 :(得分:1)

如果只是访问检查点中的变量,请查看checkpoint_utils库。它提供了三个有用的api函数:load_checkpointlist_variablesload_variable。我不确定是否有更好的方法,但你当然可以使用这些函数来提取检查点中所有变量的字典,如下所示:

import tensorflow as tf

ckpt = 'checkpoints5/text8.ckpt'
var_dict = {name: tf.train.load_checkpoint(ckpt).get_tensor(name)
            for name, _ in tf.train.list_variables(ckpt)}
print(var_dict)

要加载预训练模型而不必重新定义所有变量,您需要的不仅仅是检查点。检查点只有变量,它没有如何恢复这些变量,即如何将它们映射到图形,没有实际图形(和适当的图)。对于这种情况,SavedModel会更好。它可以保存模型MetaGraph和所有变量。恢复已保存的模型时,您不必手动重新定义所有内容。以下代码仅使用simple_save

保存训练有素的模型:

import tensorflow as tf

x = tf.placeholder(tf.float32)
y_ = tf.reshape(x, [-1, 1])
y_ = tf.layers.dense(y_, units=1)
loss = tf.losses.mean_squared_error(labels=x, predictions=y_)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for _ in range(10):
        sess.run(train_op, feed_dict={x: range(10)})
    # Let's check the bias here so that we can make sure
    # the model we restored later on is indeed our trained model here.
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))
    tf.saved_model.simple_save(sess, 'test', inputs={"x": x}, outputs={"y": y_})

要恢复已保存的模型:

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # A model saved by simple_save will be treated as a graph for inference / serving,
    # i.e. uses the tag tag_constants.SERVING
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 'test')
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))