Tensorflow Scikit Flow获取Android版GraphDef(保存* .pb文件)

时间:2016-09-25 00:51:06

标签: scikit-learn tensorflow tensorflow-serving

我想在Android应用中使用我的Tensorflow算法。 Tensorflow Android示例首先下载包含模型定义和权重的GraphDef(在* .pb文件中)。现在这应该来自我的Scikit Flow算法(Tensorflow的一部分)。

乍一看,你只需说classifier.save('model /'),但保存到该文件夹​​的文件不是* .ckpt,* .def,当然不是* .pb。相反,你必须处理* .pbtxt和检查点(没有结束)文件。

很长一段时间以来我一直困在那里。这是一个导出内容的代码示例:

#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics

#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)

您获得的文件是:

  • 检查点
  • graph.pbtxt
  • model.ckpt-1.meta
  • model.ckpt-1-00000-的-00001
  • model.ckpt-200.meta
  • model.ckpt-200-00000-的-00001

我发现的许多可能的解决方法都需要在变量中使用GraphDef(不知道如何使用Scikit Flow)。或者使用Scikit Flow似乎不需要Tensorflow会话。

1 个答案:

答案 0 :(得分:2)

要保存为pb文件,您需要从构造的图形中提取graph_def。你可以这样做 -

from tensorflow.python.framework import tensor_shape, graph_util
from tensorflow.python.platform import gfile
sess = tf.Session()
final_tensor_name = 'results:0'     #Replace final_tensor_name with name of the final tensor in your graph
#########Build your graph and train########
## Your tensorflow code to build the graph
###########################################

outpt_filename = 'output_graph.pb'
output_graph_def = sess.graph.as_graph_def()
with gfile.FastGFile(outpt_filename, 'wb') as f:
  f.write(output_graph_def.SerializeToString())

如果要将训练过的变量转换为常量(以避免使用ckpt文件加载权重),可以使用:

output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name])

希望这有帮助!