Freezing a tensor_forest graph for Android

Posted: 2017-08-24 08:47:23

Tags: android tensorflow deep-learning tensor

I built a simple random forest model in TensorFlow and would like to freeze & optimize it for Android. I build the tensor_forest estimator with the following function:

def build_estimator(_model_dir, _num_classes, _num_features, _num_trees, _max_nodes):
    params = tensor_forest.ForestHParams(
        num_classes=_num_classes, num_features=_num_features,
        num_trees=_num_trees, max_nodes=_max_nodes, min_split_samples=3)

    graph_builder_class = tensor_forest.RandomForestGraphs
    return random_forest.TensorForestEstimator(
        params, graph_builder_class=graph_builder_class,
        model_dir=_model_dir)

This function stores the model as a text-format graph.pbtxt file in the specified model directory.
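
The snippet above assumes the tensor_forest and random_forest modules are already imported; a sketch of the imports it presumably relies on, using the TF 1.x contrib paths (these paths are my assumption and are not shown in the post):

# Assumed TF 1.x contrib imports for the estimator snippet above.
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.contrib.tensor_forest.client import random_forest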

I then train it with:

est = build_estimator(output_model_dir, 3, np.size(features_eval, 1), 5, 6)
train_X = features_eval.astype(dtype=np.float32)
train_Y = labels_y.astype(dtype=np.float32)
est.fit(x=train_X, y=train_Y, batch_size=np.size(features_eval, 0))

(In this simple example: number of trees = 5, max_nodes = 6.)
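
As a quick sanity check of the trained estimator before freezing, one could run a prediction; a rough sketch, assuming the contrib.learn predict(x=...) interface mirrors the fit(x=..., y=...) call above and reusing the variables from the previous snippet:

# Hypothetical sanity check; est and features_eval come from the snippets above.
test_X = features_eval.astype(np.float32)
predictions = list(est.predict(x=test_X))
print(predictions[:3])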

Now I want to freeze the model, so I call this function:

def save_model_android(model_path):
    checkpoint_state_name = "model.ckpt-1"
    input_graph_name = "graph.pbtxt"
    output_graph_name = "freezed_model.pb"
    checkpoint_path = os.path.join(model_path, checkpoint_state_name)

    input_graph_path = os.path.join(model_path, input_graph_name)
    input_saver_def_path = None
    input_binary = False
    output_node_names = "output"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(model_path, output_graph_name)
    clear_devices = True

    freeze_graph(input_graph_path, input_saver_def_path,
                 input_binary, checkpoint_path,
                 output_node_names, restore_op_name,
                 filename_tensor_name, output_graph_path,
                 clear_devices, "")

The resulting freezed_model.pb file contains only 1 op, which is the output node. In the console, when freeze_graph is called, I get the following message:

Converted 0 variables to const ops.
1 ops in the final graph.

Does anyone know why only a single node is exported when freeze_graph is called?

I am using TensorFlow 1.2.1 with CUDA support, built from source on Linux.
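
One way to narrow this down is to list the node names that actually appear in graph.pbtxt and check whether any node is literally named "output"; a minimal sketch, assuming TF 1.x and a placeholder path for the model directory:

import tensorflow as tf
from google.protobuf import text_format

# Parse the text-format GraphDef written by the estimator.
graph_def = tf.GraphDef()
with open('/path/to/model_dir/graph.pbtxt') as f:  # placeholder path
    text_format.Merge(f.read(), graph_def)

# Print every node so a valid output_node_names value can be chosen.
for node in graph_def.node:
    print(node.name, node.op)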

2 answers:

Answer 0 (score: 0)

I ran into the same problem. The conversion code below turns another ckpt model into a .pb without any trouble, but it throws an error when converting a DeepLabV3 ckpt model, and I don't know why.

The conversion code (rough as it is):

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile


def freeze_graph(ckpt, output_graph):
    output_node_names = "logits/biases"
    saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=False)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, ckpt)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,
            output_node_names=output_node_names.split(',')
        )
        with tf.gfile.GFile(output_graph, 'wb') as fw:
            fw.write(output_graph_def.SerializeToString())
        print('{} ops in the final graph.'.format(len(output_graph_def.node)))


ckpt = './6/model.ckpt'
pb = './6/modelxxxxxx.pb'

if __name__ == '__main__':
    freeze_graph(ckpt, pb)
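
As a quick check, the resulting .pb can be loaded back to count its ops; a small sketch, assuming TF 1.x and the output path used above:

import tensorflow as tf

# Load the frozen GraphDef and import it into a fresh graph.
with tf.gfile.GFile('./6/modelxxxxxx.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default():
    tf.import_graph_def(graph_def, name='')
    print('{} ops loaded back from the frozen graph.'.format(len(graph_def.node)))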

Answer 1 (score: 0)

Problem solved: I needed to add another output node to the graph, as described in https://github.com/GeorgeSeif/Semantic-Segmentation-Suite/issues/63
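
A common way to add such an output node is to wrap the prediction tensor in a tf.identity op with a fixed name before the graph is written out, so that output_node_names refers to a node that really exists; a minimal sketch in TF 1.x (the input tensor here is a hypothetical stand-in for the model's prediction tensor):

import tensorflow as tf

# Hypothetical stand-in for the model's prediction tensor.
logits = tf.placeholder(tf.float32, shape=[None, 3], name='logits_in')

# Give it an explicitly named node so output_node_names="output"
# matches a node that actually exists in the exported graph.
output = tf.identity(logits, name='output')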