在tensorflow中恢复图失败,因为没有要保存的变量

时间:2016-05-04 09:39:47

标签: python tensorflow restore

我知道堆栈和github等上有无数关于如何在Tensorflow中恢复训练模型的问题。我已阅读其中大​​部分内容(123)。

我有几乎完全相同的问题3然而我想尽可能以不同的方式解决它,因为我的训练和我的测试需要在从shell调用的单独脚本中,我不想添加确切的用于在测试脚本中定义图形的相同行,因此我不能使用tensorflow FLAGS和其他基于手动重新运行图形的答案。

我也不想sess.run每个变量并手动映射它们,因为我的图表非常大(使用带有参数input_map的import_graph_def)。

所以我运行一些图表并在特定的脚本中训练它。比如(但没有训练部分)

#Script 1
import tensorflow as tf
import cPickle as pickle

x=tf.Variable(42)
saver=tf.train.Saver()
sess=tf.Session()
#Saving the graph
graph_def=sess.graph_def
with open('graph.pkl','wb') as output:
  pickle.dump(graph_def,output,HIGHEST_PROTOCOL)


#Training the model
sess.run(tf.initialize_all_variables())
#Saving the variables
saver.save(sess,"pretrained_model.ckpt")

我现在已经保存了图形和变量,所以我应该能够从另一个脚本运行我的测试模型,即使我的图形中有额外的训练节点。

#Script 2
import tensorflow as tf
import cPickle as pickle

sess=tf.Session()
with open("graph.pkl','rb') as input:
  graph_def=pickle.load(input)



tf.import_graph_def(graph_def,name='persisted')

然后显然我想使用保护程序恢复变量,但我遇到了与3相同的问题,因为没有找到保存甚至创建保护程序的变量。 所以我不能写:

saver=tf.train.Saver()
saver.restore(sess,"pretrained_model.ckpt")

有没有办法绕过我想通过导入图形的那些限制它会在每个节点中恢复未初始化的变量,但似乎不是我真的需要像大多数给出的答案那样第二次重新运行它?

2 个答案:

答案 0 :(得分:6)

变量列表保存在Collection中,但未保存在GraphDef中。默认情况下,Saver使用ops.GraphKeys.VARIABLES集合中的列表(可通过tf.all_variables()访问),如果您从GraphDef还原而不是使用Python API构建模型,则该集合是空。您可以在tf.train.Saver(var_list=['MyVariable1:0', 'MyVariable2:0',...])手动指定变量列表。

另外,GraphDef代替MetaGraphDef您可以使用保存集合的var map, places, iw; var markers = []; var autocomplete; var defaultLatLng = new google.maps.LatLng(34.0983425, -118.3267434); var myLatlng = new google.maps.LatLng(37.783259, -122.402708); function initialize() { var myOptions = { zoom: 17, //center: myLatlng, mapTypeId: google.maps.MapTypeId.ROADMAP } if (navigator.geolocation) { browserSupportFlag = true; navigator.geolocation.getCurrentPosition(function(position) { initialLocation = new google.maps.LatLng(position.coords .latitude, position.coords.longitude); map.setCenter(initialLocation); map.setZoom(16); var marker = new google.maps.Marker({ position: initialLocation, title: "You are here!" }); marker.setMap(map); }, function() { handleNoGeolocation(browserSupportFlag); }); } map = new google.maps.Map(document.getElementById("map_canvas"), myOptions); places = new google.maps.places.PlacesService(map); google.maps.event.addListener(map, 'tilesloaded', tilesLoaded); autocomplete = new google.maps.places.Autocomplete(document.getElementById( 'autocomplete')); google.maps.event.addListener(autocomplete, 'place_changed', function() { showSelectedPlace(); }); } ,但最近添加了MetaGraphDef HowTo

答案 1 :(得分:-1)

根据我的知识和测试,您不能简单地将名称传递给tf.train.Saver对象。它必须是字典变量列表。

我还想从 graph_def 中读取模型,然后使用saver加载变量 - 但尝试仅在错误消息中生成:“要保存的变量不是变量”