在Python中加载和使用经过训练的TensorFlow模型

时间:2019-02-02 02:41:43

标签: python tensorflow

我使用tf.estimator API(更具体地说是使用tf.estimator.train_and_evaluate)在TensorFlow中训练了一个模型。我有培训的输出目录。如何从中加载我的模型然后使用它?

我尝试通过加载最新的tf.train.Saver文件并还原会话来使用ckpt类。但是,然后要调用sess.run(),我需要知道图的输出节点的名称是什么,以便将其传递给fetches参数。名称是什么/如何访问此输出节点?有没有更好的方法来加载和使用经过训练的模型?

请注意,我已经训练了模型并将其保存在ckpt文件中,所以请不要建议我使用simple_save函数。

2 个答案:

答案 0 :(得分:0)

您可以使用tf.train.list_variables('ckpt file')检查保存在检查点文件中的所有变量。您可以使用tf.train.init_from_checkpoint()从文件开始启动。您只能通过使用贴图来使用检查点中存在的变量

 variables_from_ckpt = [i[0] for i in tf.train.list_variables('ckpt file')]  
 assigment_map = {variable.op.name: variable for variable in tf.global_variables() if variable.op.name in variables_from_ckpt}

答案 1 :(得分:0)

(回答我自己的问题),我意识到最简单的方法是使用tf.estimator API。通过初始化一个从模型目录开始的估计器,可以调用estimator.predict并传递正确的args(predict_fn)并立即获得预测。不需要以任何方式处理图形变量。