如何在tensorflow中恢复占位符?

时间:2017-12-26 10:30:26

标签: python tensorflow

我在tensorflow中保存了一个模型,我想恢复它以供进一步使用,但是我收到了一个错误。代码在某种程度上如下:

import tensorflow as tf
def input_func(dim):
    input_ = tf.placeholder(tf.float32,[1,dim])
    return input_
def fully_connect(input_,out_dimension):
    out=tf.layers.dense(input_, out_dimension,\
        kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False))
    return tf.reduce_sum(out)
def train(real_input, input_dim, out_dimension):
    input_ = input_func(input_dim)
    output = fully_connect(input_, out_dimension)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(10):
            sess.run(output, {input_:real_input})

        tf.add_to_collection('input_',input_)
        tf.add_to_collection('output',output)
        tf.train.Saver().save(sess,'./save/expression') 
dim=3
out_dimension=2
real_input=[[1,2,3]]
with tf.Graph().as_default():
    train(real_input, dim, out_dimension)

现在建立并保存模型。

稍后恢复模型我使用了以下代码:

with tf.Session() as sess:
    loader = tf.train.import_meta_graph('./save/expression.ckpt.meta')
    loader.restore(sess, './save/expression.ckpt')
    input_=tf.get_collection('input_')
    print(input_)
    output=tf.get_collection('output')
    print(sess.run(output, {input_:[[4,5,6]]}))

但是我遇到了一个错误:

INFO:tensorflow:Restoring parameters from ./save/expression.ckpt
[]

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-98-6cfbdc96438e> in <module>()
      5     print(input_)
      6     output=tf.get_collection('output')
----> 7     print(sess.run(output, {input_:[[4,5,6]]}))

TypeError: unhashable type: 'list'

似乎占位符input_未保存!

有人可以帮助我吗?

1 个答案:

答案 0 :(得分:3)

您必须恢复占位符并为其提供适当的值。理想情况下,您应该在创建占位符时为其命名。由于您尚未命名,因此必须从图表中找到该名称。 恢复模型后,打印出图表中节点的名称,将首先打印占位符。 您可以使用

执行此操作
with tf.Session() as sess:
    loader = tf.train.import_meta_graph('./save/expression.ckpt.meta')
    loader.restore(sess, './save/expression.ckpt')
    graph = tf.get_default_graph()
    for op in graph.get_operations():
        print(op.name)

我猜输入占位符将被赋予默认名称“Placeholder”。 找到其名称后,您必须恢复该张量并为其提供值。 如果名称为Placeholder,则可以使用

进行恢复

graph.get_tensor_by_name('Placeholder:0')

您应该以相同的方式找到输出节点的名称。它应该类似fully_connected_1/matmul...,让我们假设名称为outputNodeName。 然后您可以将图表作为

运行
with tf.Session() as sess:
    loader = tf.train.import_meta_graph('./save/expression.ckpt.meta')
    loader.restore(sess, './save/expression.ckpt')
    graph = tf.get_default_graph()
    input_= graph.get_tensor_by_name('Placeholder:0')
    output=tf.get_collection('outputNodeName:0')
    print(sess.run(output, {input_:[[4,5,6]]}))