保存和使用TensorForestEstimator for Android

时间:2016-11-28 16:49:54

标签: android tensorflow skflow

我使用tensorflow中实现的randomforest估算器来预测文本是否为英文。我使用以下代码(train_input_fn函数返回功能和类标签)保存了我的模型(带有2k样本和2个类标签0/1(非英语/英语)的数据集):

model_path='test/'
TensorForestEstimator(params, model_dir='model/')
estimator.fit(input_fn=train_input_fn, max_steps=1)

运行上述代码后,graph.pbtxt和检查点将保存在模型文件夹中。现在我想在Android上使用它。我有两个问题:

  1. 作为第一步,我需要将图表和检查点冻结为.pb文件,以便在Android上使用它。我尝试了freeze_graph(我在这里使用了代码:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)。当我在我的模式中调用freeze_graph时,我收到以下错误,代码无法创建最终的.pb图:

    文件" /Users/XXXXXXX/freeze_graph.py",第105行,在freeze_graph中     _ = tf.import_graph_def(input_graph_def,name ="")   文件" /anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/importer.py",第258行,在import_graph_def中     op_def = op_dict [node.op] KeyError:u' CountExtremelyRandomStats'

  2. 这就是我所说的freeze_graph:

    def save_model_android():
        checkpoint_state_name = "model.ckpt-1"
        input_graph_name = "graph.pbtxt"
        output_graph_name = "output_graph.pb"
        checkpoint_path = os.path.join(model_path, checkpoint_state_name)
    
        input_graph_path = os.path.join(model_path, input_graph_name)
        input_saver_def_path = None
        input_binary = False
        output_node_names = "output"
        restore_op_name = "save/restore_all"
        filename_tensor_name = "save/Const:0"
        output_graph_path = os.path.join(model_path, output_graph_name)
        clear_devices = True
    
        freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                                  input_binary, checkpoint_path,
                                  output_node_names, restore_op_name,
                                  filename_tensor_name, output_graph_path,
                                  clear_devices, "")
    

    我还尝试冻结" tf.contrib.learn.datasets.load_iris"中的虹膜数据集。我犯了同样的错误。所以我认为它与数据集无关。

    1. 作为第二步,我需要使用手机上的.pb文件来预测文本。我发现谷歌的相机演示示例,它包含很多代码。我想知道是否有一步一步的教程如何通过传递特征向量并获取类标签来在Android上使用Tensorflow模型。
    2. 先谢谢!

      更新

      通过使用最新版本的tensorflow(0.12),问题得以解决。但是,现在,问题是我应该传递给output_node_names ???如何获取图表中的输出节点?

2 个答案:

答案 0 :(得分:1)

Re(1)看起来你在tensorflow的构建上运行freeze_graph,它无法访问contrib操作。也许在调用freeze_graph之前尝试显式导入tensorforest?

Re(2)我不知道一个更简单的例子。

答案 1 :(得分:0)

CountExtremelyRandomStats是TensorForest的自定义操作之一,存在于tensorflow / contrib中。正如所指出的那样,TF在某些时候默认切换到包含contrib ops。我不认为在以前的版本中将contrib自定义操作包含在全局注册表中是一种简单的方法,因为TensorForest使用构建包含为数据文件的.so文件的方法。在运行时加载(在创建TensorForest时是标准的方法,但可能不再是这样)。因此,没有容易包含的python构建规则可以在C ++自定义操作中正确链接。您可以尝试将tensorflow / contrib / tensor_forest:ops_lib作为构建规则中的dep包含在内,但我认为它不会起作用。

在任何情况下,您都可以尝试安装tensorflow的每晚构建。替代方案包括修改tensorforest自定义操作的构建方式,这非常讨厌。