我使用tensorflow中实现的randomforest估算器来预测文本是否为英文。我使用以下代码(train_input_fn函数返回功能和类标签)保存了我的模型(带有2k样本和2个类标签0/1(非英语/英语)的数据集):
model_path='test/'
TensorForestEstimator(params, model_dir='model/')
estimator.fit(input_fn=train_input_fn, max_steps=1)
运行上述代码后,graph.pbtxt和检查点将保存在模型文件夹中。现在我想在Android上使用它。我有两个问题:
作为第一步,我需要将图表和检查点冻结为.pb文件,以便在Android上使用它。我尝试了freeze_graph(我在这里使用了代码:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)。当我在我的模式中调用freeze_graph时,我收到以下错误,代码无法创建最终的.pb图:
文件" /Users/XXXXXXX/freeze_graph.py",第105行,在freeze_graph中 _ = tf.import_graph_def(input_graph_def,name ="") 文件" /anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/importer.py",第258行,在import_graph_def中 op_def = op_dict [node.op] KeyError:u' CountExtremelyRandomStats'
这就是我所说的freeze_graph:
def save_model_android():
checkpoint_state_name = "model.ckpt-1"
input_graph_name = "graph.pbtxt"
output_graph_name = "output_graph.pb"
checkpoint_path = os.path.join(model_path, checkpoint_state_name)
input_graph_path = os.path.join(model_path, input_graph_name)
input_saver_def_path = None
input_binary = False
output_node_names = "output"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(model_path, output_graph_name)
clear_devices = True
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
input_binary, checkpoint_path,
output_node_names, restore_op_name,
filename_tensor_name, output_graph_path,
clear_devices, "")
我还尝试冻结" tf.contrib.learn.datasets.load_iris"中的虹膜数据集。我犯了同样的错误。所以我认为它与数据集无关。
先谢谢!
通过使用最新版本的tensorflow(0.12),问题得以解决。但是,现在,问题是我应该传递给output_node_names ???如何获取图表中的输出节点?
答案 0 :(得分:1)
Re(1)看起来你在tensorflow的构建上运行freeze_graph,它无法访问contrib操作。也许在调用freeze_graph之前尝试显式导入tensorforest?
Re(2)我不知道一个更简单的例子。
答案 1 :(得分:0)
CountExtremelyRandomStats是TensorForest的自定义操作之一,存在于tensorflow / contrib中。正如所指出的那样,TF在某些时候默认切换到包含contrib ops。我不认为在以前的版本中将contrib自定义操作包含在全局注册表中是一种简单的方法,因为TensorForest使用构建包含为数据文件的.so文件的方法。在运行时加载(在创建TensorForest时是标准的方法,但可能不再是这样)。因此,没有容易包含的python构建规则可以在C ++自定义操作中正确链接。您可以尝试将tensorflow / contrib / tensor_forest:ops_lib作为构建规则中的dep包含在内,但我认为它不会起作用。
在任何情况下,您都可以尝试安装tensorflow的每晚构建。替代方案包括修改tensorforest自定义操作的构建方式,这非常讨厌。