冻结mobilenet_v2(freeze_graph.py)

时间:2018-08-11 06:13:07

标签: tensorflow

我使用苗条的train_image_classifier训练了mobilenet_V2。我的命令是:

python3 ~/models/research/slim/train_image_classifier.py \
--model_name="mobilenet_v2" \
--learning_rate=0.045 * 2 \
--preprocessing_name="inception_v2" \
--label_smoothing=0.1 \
--moving_average_decay=0.9999 \
--batch_size=16 \
--num_clones=2 \
--learning_rate_decay_factor=0.98 \
--num_epochs_per_decay=2.5 / 2 \
--train_dir=[...]/tensorflow_logs/mobilenet_v2 \
--dataset_dir=[...]/imagenet_data \
--dataset_name='imagenet' \
--train_image_size=229

它在mobilenet_v2 /中创建了以下文件(不是完整列表):

checkpoint
graph.pbtxt
[...]
model.ckpt-697555.data-00000-of-00001
model.ckpt-697555.index
model.ckpt-697555.meta

我可以使用检查点进行推断。

我现在正努力将检查点变量转换为带有冻结图的Const ops。当我尝试时:

python3 -m tensorflow.python.tools.freeze_graph \
  --input_graph [...]/tensorflow_logs/mobilenet_v2/graph.pbtxt \
  --input_checkpoint [...]/tensorflow_logs/mobilenet_v2/model.ckpt-697555 \
  --input_binary false \
  --output_graph /mnt/sda1/tensorflow_logs/mobilenet_v2/mobilenet_v2_frozen.pb \
  --output_node_names MobilenetV2/Predictions/Reshape_1

我得到:AssertionError: MobilenetV2/Predictions/Reshape_1 is not in graph 但是,[print(n.name) for n in tf.get_default_graph().as_graph_def().node]的输出包含:

MobilenetV2/Logits/Squeeze
MobilenetV2/Logits/output
MobilenetV2/Predictions/Reshape/shape
MobilenetV2/Predictions/Reshape
MobilenetV2/Predictions/Softmax
MobilenetV2/Predictions/Shape
MobilenetV2/Predictions/Reshape_1

引起我的困惑。

上次我能够冻结我的图形时,我使用了input_saved_model_dir选项来加载保存的模型[我相信现在是首选方法?]代替了input_graph和input_checkpoint。不幸的是,train_image_classifier.py脚本不会创建一个,只会创建检查点。

任何想法/评论表示赞赏!

0 个答案:

没有答案