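I have trained two models separately, and I want to load their variables and average them. But `tf.get_default_graph()` goes wrong. Here is the structure of my code (I know it's wrong, but how do I write it correctly?):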
sess = tf.Session()
saver_one = tf.train.import_meta_graph('./model1.ckpt.meta')
saver_one.restore(sess,'./model1.ckpt')
graph_one = tf.get_default_graph()
wc1 = graph_one.get_tensor_by_name('wc1:0')
...
saver_two = tf.train.import_meta_graph('./model2.ckpt.meta')
saver_two.restore(sess,'./model2.ckpt')
graph_two = tf.get_default_graph()
wc1_two = graph_two.get_tensor_by_name('wc1:0')
...
The error:
Traceback (most recent call last):
  File "/home/dan/Documents/deep-prior-master/src/ESB_ICVL_TEST_ALL.py", line 143, in <module>
    saver_two.restore(sess,'./cache/cnn_shallow/model2.ckpt')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1548, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 789, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 997, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1132, in _do_run
    target_list, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [27] rhs shape= [9]
     [[Node: save/Assign_6 = Assign[T=DT_FLOAT, _class=["loc:@outb"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/gpu:0"](outb, save/RestoreV2_6/_1)]]
Thanks a lot for any suggestions. =(^.^)=
Answer (score: 1)
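Both `import_meta_graph` calls load into the same default graph, so the second restore tries to write model2's values into variables that already exist with model1's shapes, and some of them don't match (the error shows `outb` has shape [27] in the graph but [9] in the checkpoint). It is better to keep each model in its own graph and session: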
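To see which variables actually collide, you can compare the variable shapes stored in the two checkpoints. The sketch below is plain Python 3 and assumes you have already built a `{name: shape}` mapping per model; if your TensorFlow version provides `tf.train.list_variables`, something like `{name: tuple(shape) for name, shape in tf.train.list_variables(path)}` should produce one. The helper `find_shape_mismatches` is hypothetical, and every shape below except `outb` ([27] vs. [9], taken from the traceback) is made up for illustration.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    shapes_one: Dict[str, Shape], shapes_two: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, shape_one, shape_two) for every variable that exists in
    both models but with different shapes. Restoring one model's checkpoint
    into the other model's variables fails on exactly these names."""
    mismatches = []
    # dict key views support set intersection; sort for a stable order
    for name in sorted(shapes_one.keys() & shapes_two.keys()):
        if shapes_one[name] != shapes_two[name]:
            mismatches.append((name, shapes_one[name], shapes_two[name]))
    return mismatches


# Hypothetical shapes; only 'outb' mirrors the traceback (lhs [27], rhs [9]).
model_one = {"wc1": (5, 5, 1, 8), "outb": (27,)}
model_two = {"wc1": (5, 5, 1, 8), "outb": (9,)}
print(find_shape_mismatches(model_one, model_two))  # [('outb', (27,), (9,))]
```

Any name this reports cannot be averaged element-wise either, so it is worth checking before writing the averaging code.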
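import tensorflow as tf  # TensorFlow 1.x API

graph_one = tf.Graph()
with graph_one.as_default():
    session_one = tf.Session(graph=graph_one)
    saver_one = tf.train.import_meta_graph('./model1.ckpt.meta')
    # import_meta_graph only rebuilds the graph; restore() loads the trained values
    saver_one.restore(session_one, './model1.ckpt')
    # running a single tensor (not a list) returns a NumPy array
    wc1_one_value = session_one.run(graph_one.get_tensor_by_name('wc1:0'))
# Similar for graph_two / session_two, giving wc1_two_value
...
print((wc1_one_value + wc1_two_value) / 2)  # or whatever you want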
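Because `session.run` on a single tensor returns a NumPy array, `(a + b) / 2` already averages element-wise, but NumPy will broadcast arrays of some mismatched shapes instead of failing. The plain-Python sketch below spells out the same element-wise mean for nested lists (for example the result of `array.tolist()`) and refuses to average tensors whose shapes differ. `average_weights` is a hypothetical helper, not part of TensorFlow.

```python
from typing import List, Union

# A scalar or an arbitrarily nested list of scalars, e.g. from ndarray.tolist()
Nested = Union[float, List["Nested"]]


def average_weights(a: Nested, b: Nested) -> Nested:
    """Element-wise mean of two nested lists that must have identical shapes."""
    if isinstance(a, list) and isinstance(b, list):
        if len(a) != len(b):
            raise ValueError(f"shape mismatch: length {len(a)} vs {len(b)}")
        return [average_weights(x, y) for x, y in zip(a, b)]
    if isinstance(a, list) or isinstance(b, list):
        raise ValueError("shape mismatch: different nesting depth")
    return (a + b) / 2


print(average_weights([[1.0, 2.0], [3.0, 4.0]], [[3.0, 2.0], [1.0, 0.0]]))
# [[2.0, 2.0], [2.0, 2.0]]
```

Raising instead of broadcasting keeps the failure close to the cause, much like TensorFlow's own `Assign requires shapes of both tensors to match` check.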
To write the averaged values back into a session, re-enter its graph and run `tf.assign` operations there:
with graph_one.as_default(), session_one.as_default():
    session_one.run(tf.assign(<variable>, (wc1_one_value + wc1_two_value) / 2))
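To get the variable object, you can look it up in `tf.trainable_variables()` (inside `graph_one.as_default()`), or define the variables again in a variable scope with `reuse=True`.
Then save (export) the model again.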