我用OpenAI训练了deepq模型。做完之后
saver = tf.train.Saver()
saver.save(tf.get_default_session(), 'my_deepq')
我得到了以下文件:
my_deepq.data-00000-of-00001
my_deepq.index
checkpoint
my_deepq.meta
然后我需要在两个不同的系统(C ++和python)中加载此模型以进行推断。
对于python部分,我尝试过:
import tensorflow as tf
tf.reset_default_graph()
imported_graph = tf.train.import_meta_graph('my_deepq.meta')
with tf.Session() as sess:
imported_graph.restore(sess, './my_deepq')
代码已经运行,但是我不确定模型在哪里加载以及如何进行推断。有人可以建议。
对于C ++方面,我将做类似的事情:
tensorflow::Session *my_sess;
tensorflow::Status status = tensorflow::NewSession(options, &my_sess);
tensorflow::GraphDef graph_def;
status = ReadBinaryProto(tensorflow::Env::Default(), model_path, &graph_def);
status = my_sess->Create(graph_def);
tensorflow::Status status = my_sess->Run({{"My_Input", input_tensor}}, {"My_Output"}, {}, &output_tensor);
此方法要求模型采用BinaryProto格式,但是我不确定如何在python中将模型保存在BinaryProto中。任何人都可以请指教。谢谢!