Loading a frozen TF model: no value for placeholder tensor

Time: 2019-05-29 05:23:25

Tags: python tensorflow

I am trying to freeze a TensorFlow graph and restore it, but when I try to run a prediction I get this error:

You must feed a value for placeholder tensor 'DQNetwork/actions' with dtype float and shape [?,10] 

My restore code is:

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()
graph = tf.get_default_graph()

with graph.as_default():
    with sess.as_default():
        GRAPH_PB_PATH = "./frozentensorflowModel.pb"
        # Parse the GraphDef inside the `with` block, while the file is still open
        with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

        x_tensor = graph.get_tensor_by_name("DQNetwork/inputs:0")
        op_to_restore = graph.get_tensor_by_name("DQNetwork/actions:0")

        # new_state, cards and game_state come from my game code (not shown)
        new_state(cards.copy())
        state = game_state.state

        # Add a batch dimension: (state_size,) -> (1, state_size)
        feed_dict = {x_tensor: state.reshape((1, *state.shape))}
        opt = sess.run(op_to_restore, feed_dict)  # The error is raised here
        predictions = np.argmax(opt, 1)
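The last line turns the network output into an action choice. A minimal NumPy sketch of that step, assuming the output has shape (batch, action_size) with one score per action; action_size = 10 is taken from the [?, 10] shape in the error message, and the scores are made up:

```python
import numpy as np

ACTION_SIZE = 10  # assumption: same as the [?, 10] shape in the error message

# Illustrative network output for a batch of one state: one score per action
q_values = np.zeros((1, ACTION_SIZE), dtype=np.float32)
q_values[0, 7] = 1.5

# Axis 1 runs over actions, so argmax returns one action index per batch row
predictions = np.argmax(q_values, 1)
best_action = int(predictions[0])
print(predictions)  # [7]
```

Because the feed has a batch of one, `predictions` holds a single index and `predictions[0]` is the chosen action.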

I defined the DQNetwork inputs like this:

DQNetwork.inputs = tf.placeholder(tf.float32, [None, state_size], name="inputs") 
DQNetwork.actions = tf.placeholder(tf.float32, [None, action_size], name="actions")
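Both placeholders use None for the first dimension, which means "any batch size", but the rank must still match. That is why the restore code reshapes a single state to (1, state_size). The helper below is a simplified, hypothetical pure-Python illustration of that rule, not TensorFlow's own shape check; state_size = 4 is an invented value:

```python
import numpy as np

def is_compatible(placeholder_shape, value_shape):
    """Return True if value_shape fits placeholder_shape.
    None in the placeholder shape matches any size; the rank must match."""
    if len(placeholder_shape) != len(value_shape):
        return False
    return all(p is None or p == v for p, v in zip(placeholder_shape, value_shape))

STATE_SIZE = 4                      # hypothetical; the real state_size is not given
inputs_shape = [None, STATE_SIZE]   # mirrors the "inputs" placeholder

state = np.zeros(STATE_SIZE, dtype=np.float32)  # shape (4,)
batched = state.reshape((1, *state.shape))      # shape (1, 4)

print(is_compatible(inputs_shape, state.shape))    # False: rank 1 vs rank 2
print(is_compatible(inputs_shape, batched.shape))  # True
```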

More information:

>>>op_to_restore
<tf.Tensor 'DQNetwork/actions:0' shape=(?, 10) dtype=float32>
>>>op_to_restore.op
<tf.Operation 'DQNetwork/actions' type=Placeholder>
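The output above reflects TensorFlow's naming convention: a tensor name is the name of the operation that produces it, followed by ":" and the output index. So 'DQNetwork/actions:0' is output 0 of the Placeholder op 'DQNetwork/actions'; get_tensor_by_name expects the ":index" form, while get_operation_by_name expects the bare op name. The helper below is a hypothetical pure-Python illustration of the convention, not part of the TensorFlow API:

```python
def split_tensor_name(name):
    """Split a name like 'DQNetwork/actions:0' into (op_name, output_index).
    output_index is None when `name` has no ':<index>' suffix, i.e. it
    names an operation rather than one of its output tensors."""
    op_name, sep, index = name.rpartition(":")
    if sep and index.isdigit():
        return op_name, int(index)
    return name, None

print(split_tensor_name("DQNetwork/actions:0"))  # ('DQNetwork/actions', 0)
print(split_tensor_name("DQNetwork/actions"))    # ('DQNetwork/actions', None)
```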

The line I use during training:

results = sess.run(DQNetwork.output, feed_dict = {DQNetwork.inputs: input_batch})
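The error follows from how a TF1 Session works: it evaluates only the part of the graph that the fetched tensors depend on, and every placeholder in that part must be fed. Fetching 'DQNetwork/actions:0' fetches the placeholder itself, so it needs a value, whereas the training line shows that DQNetwork.output can be computed from DQNetwork.inputs alone. The pure-Python toy below models that rule on a hypothetical graph; it is not TensorFlow's implementation, and the nodes "hidden" and "loss" are invented for illustration:

```python
# Toy model of the TF1 feed rule: every placeholder that a fetched node
# transitively depends on must appear in the feed dictionary.

def required_feeds(graph, placeholders, fetch):
    """Return the placeholders that `fetch` transitively depends on."""
    needed, seen, stack = set(), set(), [fetch]
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if node in placeholders:
            needed.add(node)
        stack.extend(graph.get(node, []))  # visit this node's inputs
    return needed

# Hypothetical DQNetwork graph (node -> input nodes). "output" depending only
# on "inputs" matches the training line; "hidden" and "loss" are made up.
GRAPH = {
    "DQNetwork/hidden": ["DQNetwork/inputs"],
    "DQNetwork/output": ["DQNetwork/hidden"],
    "DQNetwork/loss": ["DQNetwork/output", "DQNetwork/actions"],
}
PLACEHOLDERS = {"DQNetwork/inputs", "DQNetwork/actions"}

def missing_feeds(fetch, feed):
    """Placeholders that would trigger 'You must feed a value ...'."""
    return required_feeds(GRAPH, PLACEHOLDERS, fetch) - set(feed)

feed = {"DQNetwork/inputs": None}  # only the inputs are fed, as in the question
print(sorted(missing_feeds("DQNetwork/output", feed)))   # []
print(sorted(missing_feeds("DQNetwork/actions", feed)))  # ['DQNetwork/actions']
```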

1 Answer:

Answer 0 (score: 0)

This might help you:

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()
graph = tf.get_default_graph()

with graph.as_default():
    with sess.as_default():
        GRAPH_PB_PATH = "./frozentensorflowModel.pb"
        # Parse the GraphDef inside the `with` block, while the file is still open
        with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

        x_tensor = graph.get_tensor_by_name("DQNetwork/inputs:0")
        # Look up the operation by its bare name (no ":0" suffix)
        op_to_restore = graph.get_operation_by_name("DQNetwork/actions")

        new_state(cards.copy())
        state = game_state.state

        feed_dict = {x_tensor: state.reshape((1, *state.shape))}
        opt = sess.run(op_to_restore, feed_dict)  # the call that failed in the question
        predictions = np.argmax(opt, 1)

That is my suggestion.

I noticed this line:

feed_dict={x_tensor: state.reshape((1, *state.shape))}

Try using sess.run(op_to_restore, feed_dict) instead of op_to_restore.eval(feed_dict).