Tensorflow保存并恢复到新的会话/图形

时间:2019-04-29 22:10:53

标签: python-3.x tensorflow

假设您进行了张量流会话,使用权重训练网络并保存图形:

import tensorflow as tf
with tf.Session(graph=tf.Graph()).as_default() as sess:
    with sess.graph.as_default():
        some_var = tf.get_variable(name='foo', shape=(4))
        x = some_var + 1.0
        some_var.load([1, 2, 3, 4])
        # My understanding is that this saves the weights only:
        tf.train.Saver().save(sess, 'my/save/path')
        # My understanding is that this saves the graph structure (not sure if it saves the weights as well):
        graph_def = sess.graph.as_graph_def()

然后,您要将其加载到新鲜会话和新鲜图形中。这样做有多种原因,例如,如果旧会话中有大量后续图形添加,而您想清除它以节省内存。另一个原因是,如果您的学习方法动态生成网络拓扑以最适合训练数据,那么不同的数据集将具有不同的结构。在这种情况下,很难简单地重新运行网络生成代码并加载一组权重。

with tf.Session(graph=tf.Graph()).as_default() as sess:
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')
        tf.train.Saver().restore(sess, 'my/save/path')  # error here

但是,当您尝试加载时,此代码失败(尽管它抱怨保存 ):

ValueError: No variables to save

如何从graph_def和/或tf.train.Saver()文件加载新会话?

0 个答案:

没有答案