我使用tf.estimator
API(更具体地说是使用tf.estimator.train_and_evaluate
)在TensorFlow中训练了一个模型。我有培训的输出目录。如何从中加载我的模型然后使用它?
我尝试通过加载最新的tf.train.Saver
文件并还原会话来使用ckpt
类。但是,然后要调用sess.run()
,我需要知道图的输出节点的名称是什么,以便将其传递给fetches
参数。名称是什么/如何访问此输出节点?有没有更好的方法来加载和使用经过训练的模型?
请注意,我已经训练了模型并将其保存在ckpt
文件中,所以请不要建议我使用simple_save
函数。
答案 0 :(得分:0)
您可以使用tf.train.list_variables('ckpt file')
检查保存在检查点文件中的所有变量。您可以使用tf.train.init_from_checkpoint()
从文件开始启动。您只能通过使用贴图来使用检查点中存在的变量
variables_from_ckpt = [i[0] for i in tf.train.list_variables('ckpt file')]
assigment_map = {variable.op.name: variable for variable in tf.global_variables() if variable.op.name in variables_from_ckpt}
答案 1 :(得分:0)
(回答我自己的问题),我意识到最简单的方法是使用tf.estimator
API。通过初始化一个从模型目录开始的估计器,可以调用estimator.predict
并传递正确的args(predict_fn
)并立即获得预测。不需要以任何方式处理图形变量。