上下文是我试图通过首先训练单个单元编码器/解码器然后扩展来逐步增长rnn自动编码器。我想加载前面单元格的参数。
这里的代码是一个最小的代码,我正在调查如何执行此操作,但它失败了:
TypeError: Cannot interpret feed_dict key as Tensor: The name 'save_1/Const:0' refers to a Tensor which does not exist. The operation, 'save_1/Const', does not exist in the graph.
我搜索过但没有找到任何内容,this thread和this thread不是同一个问题。
import tensorflow as tf
import numpy as np
with tf.Session(graph=tf.Graph()) as sess:
cell1 = tf.nn.rnn_cell.LSTMCell(1,name='lstm_cell1')
cell = tf.nn.rnn_cell.MultiRNNCell([cell1])
inputs = tf.random_normal((5,10,1))
rnn1 = tf.nn.dynamic_rnn(cell,inputs,dtype=tf.float32)
vars0 = tf.trainable_variables()
saver = tf.train.Saver(vars0,max_to_keep=1)
sess.run(tf.initialize_all_variables())
saver.save(sess,'./save0')
vars0_val = sess.run(vars0)
# creating a new graph/session because it is not given that it'll be in the same session.
with tf.Session(graph=tf.Graph()) as sess:
cell1 = tf.nn.rnn_cell.LSTMCell(1,name='lstm_cell1')
#one extra cell
cell2 = tf.nn.rnn_cell.LSTMCell(1,name='lstm_cell2')
cell = tf.nn.rnn_cell.MultiRNNCell([cell1,cell2])
inputs = tf.random_normal((5,10,1))
rnn1 = tf.nn.dynamic_rnn(cell,inputs,dtype=tf.float32)
sess.run(tf.initialize_all_variables())
# new saver with first cell variables
saver = tf.train.Saver(vars0,max_to_keep=1)
# fails
saver.restore(sess,'./save0')
# Should be the same
vars0_val1 = sess.run(vars0)
assert np.all(vars0_val1 = vars0_val)
答案 0 :(得分:0)
错误来自这条线,
saver = tf.train.Saver(vars0,max_to_keep=1)
如果是第二次会话。 vars0
指的是上一个图形中存在的实际张量对象(不是当前图形)。 Saver的var_list
需要一组实际的张量(不是字符串,我认为这些字符串足够好)。
要使其工作,应使用当前图形中的相应张量初始化第二个Saver对象。
像,
vars0_names = [v.name for v in vars0]
load_vars = [sess.graph.get_tensor_by_name(n) for n in vars0_names]
saver = tf.train.Saver(load_vars,max_to_keep=1)