假设您进行了张量流会话,使用权重训练网络并保存图形:
import tensorflow as tf
with tf.Session(graph=tf.Graph()).as_default() as sess:
with sess.graph.as_default():
some_var = tf.get_variable(name='foo', shape=(4))
x = some_var + 1.0
some_var.load([1, 2, 3, 4])
# My understanding is that this saves the weights only:
tf.train.Saver().save(sess, 'my/save/path')
# My understanding is that this saves the graph structure (not sure if it saves the weights as well):
graph_def = sess.graph.as_graph_def()
然后,您要将其加载到新鲜会话和新鲜图形中。这样做有多种原因,例如,如果旧会话中有大量后续图形添加,而您想清除它以节省内存。另一个原因是,如果您的学习方法动态生成网络拓扑以最适合训练数据,那么不同的数据集将具有不同的结构。在这种情况下,很难简单地重新运行网络生成代码并加载一组权重。
with tf.Session(graph=tf.Graph()).as_default() as sess:
with sess.graph.as_default():
tf.import_graph_def(graph_def, name='')
tf.train.Saver().restore(sess, 'my/save/path') # error here
但是,当您尝试加载时,此代码失败(尽管它抱怨保存 ):
ValueError: No variables to save
如何从graph_def和/或tf.train.Saver()文件加载新会话?