从TensorFlow Android Camera Demo重新训练Inception5h模型

时间:2017-02-02 13:45:52

标签: android python tensorflow

TensorFlow Android Camera Demo使用Inception5h model进行实时图像识别,可提供卓越的性能。由于我没有成功再培训Inception5h我已经和InceptionV3 model一起去了,但它在图像识别方面并不那么活泼。所以我回到开始尝试重新训练(或转学)Inception5h模型。我尝试修改retrain.py,但它只是为v3模型编写的。 5h模型不包含“pool_3 / _reshape:0”,“DecodeJpeg / contents:0”或“ResizeBilinear:0”张量。还有其他差异。

我是机器学习和TensorFlow的新手,所以我非常感谢我必须做的明确步骤。

谢谢!

2 个答案:

答案 0 :(得分:2)

看起来retrain.py脚本和tutorial刚刚更新,可与mobilenet架构配合使用。

所以这解决了问题的第一部分,它实际上并不是初始的5h,但它在移动设备上运行良好,准确性比inception5h好得多。

要实际让它在Android示例中运行,您仍需要更新these settings

我认为你应该只能复制the settings determined for the mobilenet you choose, from the retrain script而你可能没事。

如果你想使用另一个没有module Test : sig type t val apply : (int option -> 'a) -> t -> 'a val make_t : unit -> t val ( @ ) : (int option -> 'a) -> t -> 'a end = struct type t = int option let make_t () = Some 42 let apply f x = f x let ( @ ) = apply end let do_work : 'a option -> unit = function | Some x -> Printf.printf "Some\n" | None -> Printf.printf "None\n" let () = let open Test in do_work @ make_t () 设置的网络,那么我能想到的最简单的方法就是用TensorBoard探索图形。

因此,如果你真的想要使用inh 5h,你可以下载并解压缩它:

retrain.py

然后从Tensorflow for Poets: 2 codelab repo抓取这个简单的脚本,将图形curl -O https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip unzip -d inception5h inception5h.zip 文件转换为张量板可以使用的文件:

.pb

然后在你的graph.pb上运行它:

curl -O https://raw.githubusercontent.com/googlecodelabs/tensorflow-for-poets-2/master/scripts/graph_pb2tb.py

open it in tensorboard

mkdir tb_graph
python graph_pb2tb.py tb/inception5h inception5h/tensorflow_inception_graph.pb 

然后在图表中查找并找到填写自己的model_info词典所需的节点名称可能相对简单。

我认为这是您要设置为tensorboard --logdir tb_graph 的节点:

TensorBoard screenshot of inception 5h with avgpool0/reshape highlighted

答案 1 :(得分:1)

在retrain.py脚本结束时,您可以注意到以下几行:

output_graph_def = graph_util.convert_variables_to_constants(
  sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
   f.write(output_graph_def.SerializeToString())

这里所有变量都作为常量保存在协议缓冲区(pb)文件中,该文件是二进制文件('wb')。您还应该在文本文件中保存模型类的名称。然后在android文档中提到,你应该将这两个文件保存在tensorflow的android路径中名为“assets”的文件夹中。然后,应该进行一些修改以加载inception-v3模型,您可以在此处看到:https://github.com/tensorflow/tensorflow/issues/1269 我希望这个能帮上忙!