图像分类图模型做出错误的预测

时间:2019-12-25 16:23:24

标签: tensorflow tensorflow.js tensorflowjs-converter mobilenet

我正在使用make_image_classifier python脚本在新的一组图像上重新训练mobilenetv2。我的最终目标是在浏览器的tfjs中进行预测。

这正是我在做什么:

第1步:重新训练模型

make_image_classifier \
  --image_dir input_data \
  --tfhub_module https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4 \
  --image_size 224 \
  --saved_model_dir ./trained_model \
  --labels_output_file class_labels.txt \
  --tflite_output_file new_mobile_model.tflite

第2步:使用tensorflowjs_converter将tf保存的模型转换为图形模型

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_format=tfjs_graph_model \
    --signature_name=serving_default \
    --saved_model_tags=serve \
    trained_model/ \
    web_model/

第3步:在浏览器中加载新模型,对图像输入进行预处理,并要求模型进行预测

const model =  tf.loadGraphModel('model.json').then(function(m){
var img = document.getElementById("img");
var processed=preprocessImage(img, "mobilenet")
    window.prediction=m.predict(processed)
        window.prediction.print();
    })
})

function preprocessImage(image,modelName){
    let tensor=tf.browser.fromPixels(image)
    .resizeNearestNeighbor([224,224])
    .toFloat();
    console.log('tensor pro', tensor);
    if(modelName==undefined)
    {
        return tensor.expandDims();
    }
    if(modelName=="mobilenet")
    {
        let offset=tf.scalar(127.5);
        console.log('offset',offset);
        return tensor.sub(offset)
        .div(offset)
        .expandDims();
    }
    else
    {
        throw new Error("Unknown Model error");
    }
}

我得到无效的结果。我检查了初始模型所作的预测,这些预测是正确的,因此我在想的是转换未正确进行,或者我没有以与初始脚本相同的方式预处理图像。

帮助。

P.S:运行转换器时,我收到以下消息。不知道它是否与我正在经历的事情直接相关。

  

tensorflow / core / graph / graph_constructor.cc:750节点'StatefulPartitionedCall'具有71个输出,但是_output_shapes属性指定605个输出的形状。输出形状可能不正确。

1 个答案:

答案 0 :(得分:0)

make_image_classifier创建一个指定给tensorflow lite的saved_model。如果您想将mobilenet转换为tensorflow.js,则已在此answer中给出了要使用的命令。

您将需要使用make_image_classifier,而不是使用retrain.py,它可以被以下内容取代

curl -LO https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py