How to optimize a retrained ssd_mobilenet_v2_coco for TensorFlow.js inference?

Time: 2019-07-18 19:07:47

Tags: javascript python tensorflow tensorflow.js tensorflowjs-converter

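I am trying to retrain a mobilenet_v2 model for custom object detection by loosely following this tutorial. My end goal is a web_model that I can query and that returns scores, class IDs, and the number of detections. The final exported inference model works in a Python environment, but it currently throws strange errors once converted for the web.

It feels like a step is missing from the pipeline for converting the inference graph for the web. One problem seems to be that model_main.py sets is_training=True, which ends up baked into the final inference model. I cannot find any documentation or tutorial on how to produce a non-training model from a trained one.

I have been using tensorflow-gpu 1.13.1 and model_main.py to retrain the current ssd_mobilenet_v2_coco model provided by the object detection zoo. I have also tried the legacy train.py with tensorflow 1.14.0.

For the conversion to tfjs I used both tensorflowjs 1.2.2.1 and 0.8.6; both produce the same error when the final result is run on the web.

I have also tried applying intermediate graph transforms to the frozen model before converting it with 0.8.6.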

Training the model:

python model_main.py --model_dir=output --pipeline_config_path=training\ssd_mobilenet_v2_coco.config -num_train_steps=200000

Exporting the inference graph:

python export_inference_graph.py --input_type=image_tensor --output_directory=output_inf --pipeline_config_path=training\ssd_mobilenet_v2_coco.config --trained_checkpoint_prefix=neg_32\model.ckpt-XXXX

Converting with tfjs 1.2.2.1:

tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --saved_model_tags=serve --signature_name=serving_default output_inf\saved_model output_inf\web_model
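
One way to test the is_training suspicion after conversion would be to look at the batch-norm nodes in the generated model.json. The sketch below assumes that the file stores the graph under modelTopology.node and that each attribute uses protobuf-JSON form (for example {"b": true}); the helper name batch_norm_training_flags and the sample data are hypothetical, not part of any library.

```python
def batch_norm_training_flags(model_json):
    """Map each FusedBatchNorm* node name to its is_training flag.

    The value is True or False when the attribute is present, and None
    when it is absent (which is inconclusive, since the attribute may
    simply have been omitted from the file).
    """
    # Assumed layout: {"modelTopology": {"node": [{"name", "op", "attr"}, ...]}}
    nodes = model_json.get("modelTopology", {}).get("node", [])
    flags = {}
    for node in nodes:
        if node.get("op", "").startswith("FusedBatchNorm"):
            attr = node.get("attr", {}).get("is_training")
            flags[node.get("name")] = attr.get("b") if isinstance(attr, dict) else None
    return flags


# Hypothetical sample data shaped like the assumed model.json layout.
SAMPLE = {
    "modelTopology": {
        "node": [
            {"name": "conv/BatchNorm", "op": "FusedBatchNorm",
             "attr": {"is_training": {"b": True}}},
            {"name": "conv_1/BatchNorm", "op": "FusedBatchNorm", "attr": {}},
            {"name": "conv/Conv2D", "op": "Conv2D", "attr": {}},
        ]
    }
}

for name, flag in batch_norm_training_flags(SAMPLE).items():
    print(name, flag)
```

For a real model, the dictionary would come from json.load on output_inf/web_model/model.json. Any node reported as True would point at training-mode batch norm surviving the export.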

Testing the model in the browser:

import * as tf from '@tensorflow/tfjs';

class Detector {
    async init() {      
        try {
            this.model = await tf.loadGraphModel('/web_model/model.json');
        } catch (err) {
            console.log(err);
        }
    }

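    async detect(frame) {
        const { model } = this;

        const INPUT_TENSOR = 'image_tensor';
        const OUTPUT_TENSOR = 'num_detections';
        // Dummy all-zero input standing in for `frame` while debugging.
        const zeros = tf.zeros([1, 300, 300, 3]);

        console.log("executing model");
        // Declare `output`: assigning to an undeclared name throws in an ES module (strict mode).
        const output = await model.executeAsync({[INPUT_TENSOR]: zeros}, OUTPUT_TENSOR);
        console.log(output);
    }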
}

export default Detector;

Intermediate transformation:

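import tensorflow as tf
from tensorflow.python.framework import graph_util, ops
from tensorflow.tools.graph_transforms import TransformGraph


def get_graph_def_from_file(graph_filepath):
    # Parse a frozen GraphDef (.pb) from disk.
    with ops.Graph().as_default():
        with tf.gfile.GFile(graph_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            return graph_def


# file_name, model_dir and out_name are set elsewhere in the script (not shown).
graph_def = get_graph_def_from_file(file_name)

input_node = ['image_tensor']
# One list entry per node name, so protected_nodes can match each name individually.
output_node = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
transforms = [
 'remove_nodes(op=Identity, op=CheckNumerics)',
 'fold_constants(ignore_errors=true)',
 'fold_batch_norms',
 'fold_old_batch_norms(ignore_errors=true)',
 'merge_duplicate_nodes',
 'strip_unused_nodes'
]

# Strip training-only nodes first, keeping the output nodes.
transformed_graph_def = graph_util.remove_training_nodes(graph_def, protected_nodes=output_node)

# Pass the stripped graph (not the original graph_def) to the transforms,
# otherwise the result of remove_training_nodes is discarded.
transformed_graph_def = TransformGraph(
        transformed_graph_def,
        input_node,
        output_node,
        transforms)

tf.train.write_graph(transformed_graph_def,
                         logdir=model_dir,
                         as_text=False,
                         name=out_name)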

I expect the final web model to return detection results for the test array. Instead, when the JavaScript code runs, tensorflowjs returns the following error:

Uncaught (in promise) Error: Operands could not be broadcast together with shapes 1,150,150,32 and 0.
    at Ir (tfjs:2)
    at new bi (tfjs:2)
    at e.batchNormalization (tfjs:2)
    at kt.runKernel.$x (tfjs:2)
    at tfjs:2
    at t.scopedRun (tfjs:2)
    at t.runKernel (tfjs:2)
    at os (tfjs:2)
    at batchNorm (tfjs:2)
    at jv (tfjs:2)
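
For reference, the shapes in this message can be checked against the usual NumPy-style broadcasting rule, which tfjs appears to follow here. The sketch below is a plain-Python illustration with a hypothetical helper name (broadcast_shapes), not the tfjs implementation. It shows that a per-channel operand of shape [32] would broadcast against [1, 150, 150, 32], while an empty operand of shape [0] cannot. That would be consistent with one batch-norm input (such as the mean or variance) being empty, which may tie in with the is_training suspicion above, although this has not been verified against the graph.

```python
def broadcast_shapes(a, b):
    """Return the broadcast result shape of a and b, or raise ValueError."""
    result = []
    # Compare dimensions from the trailing end; a missing dimension counts as 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(
                "Operands could not be broadcast together with shapes "
                f"{list(a)} and {list(b)}")
    return tuple(reversed(result))


activation = (1, 150, 150, 32)

# A per-channel parameter of length 32 is compatible with the activation.
print(broadcast_shapes(activation, (32,)))

# An empty parameter (shape [0]) is not: 32 vs 0 matches neither rule.
try:
    broadcast_shapes(activation, (0,))
except ValueError as err:
    print(err)
```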

Trying to apply fold_old_batch_norms in TransformGraph then produces this error:

2019-07-07 22:16:11.717749: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying fold_old_batch_norms
Traceback (most recent call last):
  File "xxx/optimize.py", line 154, in <module>
    optimize_graph(model_dir, output_frozen_fname, transforms, output_nodes, output_optimized_fname)
  File "xxx/optimize.py", line 135, in optimize_graph
    transforms)
  File "xxx\venv\lib\site-packages\tensorflow\tools\graph_transforms\__init__.py", line 51, in TransformGraph
    transforms_string, status)
  File "xxx\venv\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Beta input to batch norm has bad shape: [32]

0 answers:

No answers yet.