我目前正在尝试从vizdoom平台上的另一个项目https://github.com/pathak22/noreward-rl中检索训练有素的TF模型。
我已成功设法使用以下方法在我的新项目中导入模型:
session = tf.Session()
print("Loading model from: ", model_savefile)
saver = tf.train.import_meta_graph(model_savefile + '.meta')
saver.restore(session, model_savefile)
但是,我无权访问生成此保存文件的代码(我认为这是通过OpenAI Gym,但不确定),因此我不知道我应该使用哪些名称来输入它。
你知道怎么做吗?
提前多多感谢
答案 0 :(得分:1)
导入MetaGraph
会将操作添加到默认图表中。
在图表中打印所有操作:
print(tf.get_default_graph().get_operations())
打印类似:
[<tf.Operation 'Placeholder' type=Placeholder>, <tf.Operation 'mul/y' type=Const>, <tf.Operation 'mul' type=Mul>]
仅打印占位符:
print([op for op in tf.get_default_graph().get_operations() if op.type == 'Placeholder'])
打印类似:
[<tf.Operation 'Placeholder' type=Placeholder>]