从Android上的Open Images数据集中获取Pretrained Inception v3模型

时间:2017-01-13 13:28:41

标签: android tensorflow

我尝试了一段时间让预先训练的模型在android上工作。问题是,我只获得了预训练网的ckpt和meta文件。在我看来,我需要Android应用程序的.pb。所以我尝试将给定的文件转换为.pb文件。

因此我尝试了freeze_graph.py但没有成功。所以我使用了https://github.com/openimages/dataset/blob/master/tools/classify.py中的示例代码并对其进行了修改以存储pb。加载后的文件

if not os.path.exists(FLAGS.checkpoint):
  tf.logging.fatal(
      'Checkpoint %s does not exist. Have you download it? See tools/download_data.sh',
      FLAGS.checkpoint)
   g = tf.Graph()
with g.as_default():
  input_image = tf.placeholder(tf.string)
   processed_image = PreprocessImage(input_image)

  with slim.arg_scope(inception.inception_v3_arg_scope()):
    logits, end_points = inception.inception_v3(
        processed_image, num_classes=FLAGS.num_classes, is_training=False)

    predictions = end_points['multi_predictions'] = tf.nn.sigmoid(
       logits, name='multi_predictions')
  init_op = control_flow_ops.group(tf.global_variables_initializer(),
                             tf.global_variables_initializer(),
                             data_flow_ops.initialize_all_tables())
  saver = tf_saver.Saver()
  sess = tf.Session()
  saver.restore(sess, FLAGS.checkpoint)

  outpt_filename = 'output_graph.pb'
  #output_graph_def = sess.graph.as_graph_def()
  output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ["multi_predictions"])
  with gfile.FastGFile(outpt_filename, 'wb') as f:
      f.write(output_graph_def.SerializeToString())

现在我的问题是我有.pb文件,但我不知道输入节点名称是什么,我不确定multi_predictions是否是正确的输出名称。在示例Android应用程序中,我必须指定两者。 Android应用程序崩溃了:

tensorflow_inference_jni.cc:138 Could not create Tensorflow Graph: Invalid argument: No OpKernel was registered to support Op 'DecodeJpeg' with these attrs. 

我不知道尝试修复.pb问题是否还有更多问题。或者,如果有人知道在我的情况下将ckpt和meta文件移植到.pd文件的更好方法,或者知道带有输入和输出名称的最终文件的来源,请给我一个提示来完成此任务。

由于

1 个答案:

答案 0 :(得分:2)

您需要使用optimize_for_inference.py脚本去除图表中未使用的节点。 Android不支持“decodeJpeg” - 像素值应直接输入。 ClassifierActivity.java提供了有关用于初始v3的特定节点的更多详细信息。