使用保存的模型进行张量流预测

时间:2018-04-18 14:31:59

标签: tensorflow

我使用此代码恢复我的模型,但我不知道如何在恢复后预测,我可以使用哪种功能?我是tensorflow的初学者,我不知道将保存哪些参数或功能。

在元模型中:

sess = tf.Session()
saver = tf.train.import_meta_graph("/home/MachineLearning/model.ckpt.meta")
saver.restore(sess,tf.train.latest_checkpoint('./'))
print("Model restored with success ")
x_predict,y_predict= load_svmlight_file('/MachineLearning/to_predict.csv')
x_predict = x_valid.toarray()
sess.run([] ,feed_dict ) #i don't know how to use predict function

结果如下:

$python predict.py
Model restored with success 
Traceback (most recent call last):
  File "predict.py", line 23, in <module>
    sess.run([] ,feed_dict )
NameError: name 'feed_dict' is not defined

1 个答案:

答案 0 :(得分:2)

你几乎就在那里。 Tensorflow只是一个数学库。您的图表是具有相关依赖关系的数学运算的集合(例如,图表,DAG)。

加载图形和关联变量(权重)后,加载了所有定义。现在您需要让tensorflow计算图中的某个值。它可以计算很多值,你想要的值通常被命名为logits(神经网络输出层的典型名称)。但请注意,它可以被命名为任何东西(特别是如果这不是神经网络模型),您需要了解模型。您可能还想计算一个名为accuracy的操作,该操作被定义为计算特定批输入的准确性(再次取决于您的模型)。

请注意,您需要提供tensorflow以及执行这些计算所需的任何内容。通常会在placeholder中传入您的数据(并且在培训过程中placeholder为您的标签提供您不需要预测的标签,因为您将要求的任何操作都不需要张量来计算在上面)。

但是您需要获得对这些不同操作的引用(logitsaccuracy)和占位符(x是典型名称)。由于您从磁盘加载了图形,因此您没有引用(请注意,加载模型的另一种方法是重新运行构建模型的代码,这使您可以轻松访问所需的引用)。

为了获得正确的引用,您可以按名称查找它们。以下是如何获得所有操作的列表:

List of tensor names in graph in Tensorflow

然后按名称获取特定的OP(操作):

How to get a tensorflow op by name?

所以你会有这样的事情:

logits = tf.get_default_graph().get_operation_by_name("logits:0")
x = tf.get_default_graph().get_operation_by_name("x:0")
accuracy = tf.get_default_graph().get_operation_by_name("accuracy:0")

请注意,:0是添加到tensorflow中所有名称的索引,以避免重复名称。现在您拥有了所需的所有参考,您可以使用sess.run执行特定计算,提供输入数据,以及您想要计算的OP:

sess.run([logits, accuracy], feed_dict={x:your_input_data_in_numpy_format})

这些元素的名称会因您的实施而有所不同,我使用了最常见的名称。如果他们没有给出漂亮的名字,那么很难识别它们,并且您需要查看生成图表的原始代码。实际上,如果它们没有正确命名,那么按名称查找它们是非常痛苦的,只是重新运行生成原始图形的代码而不是导入元图形可能更好。请注意,saver.restore仅恢复实际数据,import_meta_graph是可选项,可以通过编程方式重新构建图表来替换。