我想在Android应用中使用我的Tensorflow算法。 Tensorflow Android示例首先下载包含模型定义和权重的GraphDef(在* .pb文件中)。现在这应该来自我的Scikit Flow算法(Tensorflow的一部分)。
乍一看,你只需说classifier.save('model /'),但保存到该文件夹的文件不是* .ckpt,* .def,当然不是* .pb。相反,你必须处理* .pbtxt和检查点(没有结束)文件。
很长一段时间以来我一直困在那里。这是一个导出内容的代码示例:#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
您获得的文件是:
我发现的许多可能的解决方法都需要在变量中使用GraphDef(不知道如何使用Scikit Flow)。或者使用Scikit Flow似乎不需要Tensorflow会话。
答案 0 :(得分:2)
要保存为pb文件,您需要从构造的图形中提取graph_def。你可以这样做 -
from tensorflow.python.framework import tensor_shape, graph_util
from tensorflow.python.platform import gfile
sess = tf.Session()
final_tensor_name = 'results:0' #Replace final_tensor_name with name of the final tensor in your graph
#########Build your graph and train########
## Your tensorflow code to build the graph
###########################################
outpt_filename = 'output_graph.pb'
output_graph_def = sess.graph.as_graph_def()
with gfile.FastGFile(outpt_filename, 'wb') as f:
f.write(output_graph_def.SerializeToString())
如果要将训练过的变量转换为常量(以避免使用ckpt文件加载权重),可以使用:
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name])
希望这有帮助!