Restoring subgraph variables fails with 'Cannot interpret feed_dict key'

Asked: 2018-03-19 00:42:09

Tags: python tensorflow

For context, I am trying to grow an RNN autoencoder incrementally by first training a single-cell encoder/decoder and then expanding it. I want to load the earlier cell's parameters into the larger model.

Here is a minimal example of how I am attempting to do this, and it fails with:

TypeError: Cannot interpret feed_dict key as Tensor: The name 'save_1/Const:0' refers to a Tensor which does not exist. The operation, 'save_1/Const', does not exist in the graph.

I have searched but found nothing relevant; this thread and this thread are not the same problem.

MVCE

import tensorflow as tf
import numpy as np

with tf.Session(graph=tf.Graph()) as sess:
    cell1 = tf.nn.rnn_cell.LSTMCell(1,name='lstm_cell1')
    cell = tf.nn.rnn_cell.MultiRNNCell([cell1])

    inputs = tf.random_normal((5,10,1))
    rnn1 = tf.nn.dynamic_rnn(cell,inputs,dtype=tf.float32)
    vars0 = tf.trainable_variables()
    saver = tf.train.Saver(vars0,max_to_keep=1)
    sess.run(tf.global_variables_initializer())
    saver.save(sess,'./save0')

    vars0_val = sess.run(vars0)
# creating a new graph/session because it is not given that it'll be in the same session.  
with tf.Session(graph=tf.Graph()) as sess:
    cell1 = tf.nn.rnn_cell.LSTMCell(1,name='lstm_cell1')
    #one extra cell
    cell2 = tf.nn.rnn_cell.LSTMCell(1,name='lstm_cell2')
    cell = tf.nn.rnn_cell.MultiRNNCell([cell1,cell2])

    inputs = tf.random_normal((5,10,1))
    rnn1 = tf.nn.dynamic_rnn(cell,inputs,dtype=tf.float32)  
    sess.run(tf.global_variables_initializer())

    # new saver with first cell variables
    saver = tf.train.Saver(vars0,max_to_keep=1)

    # fails
    saver.restore(sess,'./save0')

    # Should be the same
    vars0_val1 = sess.run(vars0)
    assert all(np.all(a == b) for a, b in zip(vars0_val1, vars0_val))

1 Answer:

Answer 0 (score: 0)

The error comes from this line,

saver = tf.train.Saver(vars0,max_to_keep=1)

in the second session. vars0 refers to actual tensor objects that live in the previous graph, not the current one. Saver's var_list expects tensors from the current graph (not strings, which I had assumed would be good enough). To make it work, the second Saver object should be initialized with the corresponding tensors looked up in the current graph, like so:

# map the old tensors to their same-named counterparts in the current graph
vars0_names = [v.name for v in vars0]
load_vars = [sess.graph.get_tensor_by_name(n) for n in vars0_names]
saver = tf.train.Saver(load_vars, max_to_keep=1)
saver.restore(sess, './save0')  # now succeeds
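
The fix hinges on the fact that variable *names* are shared between the two graphs while the tensor *objects* are not. This can be sketched without TensorFlow at all, using hypothetical FakeGraph/FakeTensor classes standing in for TensorFlow's graph and tensor types:

```python
# Plain-Python sketch of why restoring by name works while passing tensor
# objects from another graph fails: each graph owns its tensor objects, and
# an object from one graph is meaningless in another.

class FakeTensor:
    def __init__(self, graph, name):
        self.graph, self.name = graph, name

class FakeGraph:
    def __init__(self, names):
        # a graph is essentially a registry from names to tensor objects
        self.tensors = {n: FakeTensor(self, n) for n in names}
    def get_tensor_by_name(self, name):
        return self.tensors[name]

# first graph: single cell
g1 = FakeGraph(["lstm_cell1/kernel:0", "lstm_cell1/bias:0"])
# second graph: same cell plus one extra
g2 = FakeGraph(["lstm_cell1/kernel:0", "lstm_cell1/bias:0",
                "lstm_cell2/kernel:0"])

vars0 = list(g1.tensors.values())          # objects belonging to g1
# re-resolve the shared names inside g2 to get g2's own objects
load_vars = [g2.get_tensor_by_name(t.name) for t in vars0]

assert all(t.graph is g2 for t in load_vars)   # now valid in graph 2
assert [t.name for t in load_vars] == [t.name for t in vars0]
```

This mirrors what `sess.graph.get_tensor_by_name` does in the answer above: the saved checkpoint is keyed by name, so any handle resolved by name in the live graph will line up with it.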