I am trying to freeze a TensorFlow graph and restore it, but when I try to run a prediction I get this error:
You must feed a value for placeholder tensor 'DQNetwork/actions' with dtype float and shape [?,10]
My restore code is:
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()
graph = tf.get_default_graph()
with graph.as_default():
    with sess.as_default():
        # Load the frozen GraphDef and import it into the default graph
        GRAPH_PB_PATH = "./frozentensorflowModel.pb"
        with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
        x_tensor = graph.get_tensor_by_name("DQNetwork/inputs:0")
        op_to_restore = graph.get_tensor_by_name("DQNetwork/actions:0")
        new_state(cards.copy())
        state = game_state.state
        # Add a leading batch axis to the single state
        feed_dict = {x_tensor: state.reshape((1, *state.shape))}
        opt = []
        opt = sess.run(op_to_restore, feed_dict)  # this line raises the error
        predictions = np.argmax(opt, 1)
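For readers unfamiliar with the two NumPy calls in the code above, here is a minimal sketch of what they do. The state vector and the Q-value array are made-up values for illustration only; they are not taken from the actual game or model.

```python
import numpy as np

# Hypothetical state vector with 4 features (the real state_size is not given)
state = np.arange(4, dtype=np.float32)

# reshape((1, *state.shape)) prepends a batch axis: (4,) becomes (1, 4),
# which matches a placeholder declared with shape [None, state_size]
batch = state.reshape((1, *state.shape))

# Hypothetical network output: one row of 10 action values
q_values = np.zeros((1, 10), dtype=np.float32)
q_values[0, 3] = 1.0

# argmax along axis 1 picks the index of the best action for each row
predictions = np.argmax(q_values, 1)
```

Here `predictions` holds one action index per row in the batch, so a single state gives an array of length 1.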
This is how I defined DQNetwork.
Inputs:
DQNetwork.inputs = tf.placeholder(tf.float32, [None, state_size], name="inputs")
DQNetwork.actions = tf.placeholder(tf.float32, [None, action_size], name="actions")
More information:
>>> op_to_restore
<tf.Tensor 'DQNetwork/actions:0' shape=(?, 10) dtype=float32>
>>> op_to_restore.op
<tf.Operation 'DQNetwork/actions' type=Placeholder>
Training line:
results = sess.run(DQNetwork.output, feed_dict = {DQNetwork.inputs: input_batch})
Answer 0 (score: 0)
This might help you:
sess = tf.Session()
graph = tf.get_default_graph()
with graph.as_default():
    with sess.as_default():
        GRAPH_PB_PATH = "./frozentensorflowModel.pb"
        with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
        x_tensor = graph.get_tensor_by_name("DQNetwork/inputs:0")
        # Changed: look up the operation by name instead of the tensor
        op_to_restore = graph.get_operation_by_name("DQNetwork/actions")
        new_state(cards.copy())
        state = game_state.state
        feed_dict = {x_tensor: state.reshape((1, *state.shape))}
        opt = []
        opt = sess.run(op_to_restore, feed_dict)  # the line that raised the error in the question
        predictions = np.argmax(opt, 1)
That is my suggestion.
I also noticed this line:
feed_dict={x_tensor: state.reshape((1, *state.shape))}
Try using sess.run(op_to_restore, feed_dict)
instead of op_to_restore.eval(feed_dict).
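One possible reading of the error, offered as a sketch and not a confirmed fix: the restore code fetches `DQNetwork/actions:0`, which the question shows is a Placeholder, while the training line fetches `DQNetwork.output`. Fetching a placeholder without feeding it raises exactly this kind of error. The toy graph below reproduces that behaviour. Its sizes, its constant weights, and the tensor name `DQNetwork/output:0` are assumptions for illustration; the real output name in the frozen graph is not given in the question. It also assumes a TensorFlow install that provides `tensorflow.compat.v1`.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or TF 2.x with the v1 compat API

STATE_SIZE, ACTION_SIZE = 4, 10  # hypothetical sizes, not from the question

graph = tf.Graph()
with graph.as_default():
    with tf.name_scope("DQNetwork"):
        inputs = tf.placeholder(tf.float32, [None, STATE_SIZE], name="inputs")
        actions = tf.placeholder(tf.float32, [None, ACTION_SIZE], name="actions")
        # Constant weights stand in for the trained network (illustration only)
        weights = tf.constant(np.ones((STATE_SIZE, ACTION_SIZE), dtype=np.float32))
        output = tf.matmul(inputs, weights, name="output")

state = np.ones((1, STATE_SIZE), dtype=np.float32)

with tf.Session(graph=graph) as sess:
    x_tensor = graph.get_tensor_by_name("DQNetwork/inputs:0")
    feed_dict = {x_tensor: state}

    # Fetching the unfed placeholder fails, as in the question
    try:
        sess.run(graph.get_tensor_by_name("DQNetwork/actions:0"), feed_dict)
        placeholder_fetch_failed = False
    except tf.errors.InvalidArgumentError:
        placeholder_fetch_failed = True

    # Fetching a tensor computed from the inputs needs only the inputs to be fed
    q_values = sess.run(graph.get_tensor_by_name("DQNetwork/output:0"), feed_dict)
    predictions = np.argmax(q_values, 1)
```

If this is the cause, the output node would also need to be among the nodes kept when the graph was frozen; listing `[op.name for op in graph.get_operations()]` on the imported graph is one way to check which names are actually available.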