Tensorflow,使用import_graph_def()加载模型错误

时间:2017-10-06 12:06:59

标签: tensorflow

我尝试通过import_graph_def()读取RNN网络并进行推理。 但我无法使用tf.trainable_variables()来获取任何变量。

在以下代码中,tf.trainable_variables()返回[](没有任何内容的列表) 此外,当我使用saver = tf.train.Saver()时,tensorflow报告“没有要保存的变量”

def eval_on_test(graph_path):
batch_size = 80
train_begin = 0
train_end = 3000
with tf.Graph().as_default() as graph:
    with open(graph_path, 'rb') as f:
        tf_graph = tf.GraphDef()
        print("Loading graph_def from {}".format(graph_path))
        tf_graph.ParseFromString(f.read())

        return_elements = tf.import_graph_def(tf_graph, name="", return_elements=['input_x:0', 'output_y:0', 'pred:0', 'loss:0'])
        X = return_elements[0]
        Y = return_elements[1]
        pred = return_elements[2]
        loss = return_elements[3]

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    print("graph loaded, start testing")
    with tf.Session(config=tf_config) as sess:

        init_op = sess.graph.get_operation_by_name('init')
        sess.run(init_op)
        print(tf.trainable_variables())
        batch_index,train_x,train_y=get_train_data(batch_size,time_step,train_begin,train_end)
        for batch in range(len(batch_index)-1):
            loss_ = sess.run(loss, feed_dict={X:train_x[batch_index[batch]:batch_index[batch+1]],Y:train_y[batch_index[batch]:batch_index[batch+1]]})
            print(batch, loss_)

任何帮助都将不胜感激。

1 个答案:

答案 0 :(得分:0)

cd breakpad git clone https://chromium.googlesource.com/linux-syscall-support src/third_party/lss 只会恢复图表,但不会恢复import_graph_def等集合,这就是GLOBAL_VARIABLES无法在图表中找到任何变量的原因,以解决此问题,您可以尝试Saver,这也将恢复所有集合。