使用TensorFlow.js的图像识别/标签

时间:2019-02-25 13:27:45

标签: javascript machine-learning image-recognition tensorflow.js

我们正在使用TensorFlow.js创建和训练自定义模型。我们使用tf.browser.fromPixels()函数将图像转换为张量。我们要创建一个自定义并训练一个自定义模型。为了实现此目的,我们创建了两个不同的网页(即:第一个页面用于创建自定义模式,并使用图像及其相关标签对其进行训练,第二个页面用于加载训练后的模型,并且通过使用预训练的模型预测图片以获取与此图片的关联标签)以实现此功能,我们研究了以下属性:

  • AddImage(HTML_Image_Element,“标签”):- 添加带有标签的imageElement。假设我们有三个用于预测的图像,即:img1,img2,img3,分别带有三个标签“ A”,“ B”和“ C”。
  • Train()/ fit():-使用相关标签训练此自定义模型。因此,我们想使用这些图像和相应的标签来创建和训练模型。

  • 保存/加载模型:-为了使预训练的模型可重复使用,我们希望保存训练的模型并在希望使用同一数据集进行预测时加载该模型。 。通过save()函数,我们得到了两个文件,即“ model.json”和“ model.weights.bin”。

  • Predict():-:加载成功训练的模型后,我们便可以将该模型加载到另一个页面中,以便用户可以使用相关标签预测图像,然后图像将返回带有每个图像的标签的预测响应。假设用户想要预测“ img1”时,它将预测显示为类名“ A”,类似地,对于“ img2”与类名“ B”进行预测,对于“ img3”与类名“ C”进行预测,置信度值为好吧。

已实现的步骤:在上述要求中,我们成功实现了以下几点:

  • 在第一个网页中,我们创建一个顺序模型,并将图像(通过将其转换为张量)添加到模型中。然后,我们使用这些图像训练/拟合该模型并关联标签。

  • 训练后,我们可以轻松地保存()这个自定义的预训练模型以进行进一步的预测。在此期间,用户可以从训练期间使用的数据集中预测出任何特定图像,该模型会给出带有关联标签的响应,即:如果用户希望对“ img1”进行预测,则模型会以标签为“ A'。

  • 保存模型后,我们现在可以在第二个网页上,用户可以使用预先训练的模型对图像进行预测,而无需进行任何训练。在此阶段,我们可以通过以下方式加载保存的模型并获得预测:

prediction:::0.9590839743614197,0.0006004410679452121,0.002040663966909051,0.001962134148925543,0.008351234719157219,0.004203603137284517,0.010159854777157307,0.007813011296093464,0.0013025108492001891,0.004482310265302658
在此响应(预测:: :)中,我们没有得到图像的任何类名/标签。它只是基于张量/图像返回置信度值。

仍需要实现:在训练自定义模型时,我们要在自定义模型中同时添加图像和标签。但是,当我们保存model.json文件时,然后在此.json文件中,我们无法在添加和训练模型时找到与图像关联的标签(“ A”,“ B”和“ C”)。因此,当我们在第二页中添加此model.json并尝试预测时,该模型不会显示关联标签。 训练/拟合时,我们无法在模型(model.json)中添加标签。 请找到代码和随附的网页屏幕截图,以便更好地理解。

以下是一些附件/示例代码,可帮助您了解要求: 预测页面(第二个网页):Prediction Page With Pre-trained model

在此处找到两个模型文件(.json和.bin): Custom model files

下面是两个页面的代码:

//2nd Page which is for prediction -
<apex:page sidebar="false" >
<head>
    <title>Predict with tensorflowJS</title> 
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.15.1"> </script>
</head>
<div class="container mt-5">
    <div class="row">
        <div class="col-12">
            <div class="progress progress-bar progress-bar-striped progress-bar-animated mb-2">Loading Model</div>
        </div>
    </div>
    <div class="row">
        <div class="col-3">
            <select id="model-selector" class="custom-select" >
                <option>mobilenet</option>
            </select>
        </div>
    </div>
    <input type="file" id="load" multiple="multiple" /><br/>
    <label for="avatar">Load Model:</label>

    <div class="row">
        <input id ="image-selector" class="form-control border-0" type="file"/>
    </div>

    <div class="col-6">
        <button id="predict-button" class="btn btn-dark float-right">Predict</button>
    </div>
</div>
<div class="row">
    <div class="col">
        <h2>Prediction</h2>
        <ol id="prediction-list"></ol>
    </div>
</div>
<div class="row">
    <div class="col-12">
        <h2 class="ml-3">Image</h2>
        <img id="selected-image" class="ml-3" src="" crossorigin="anonymous" width="400" height="300"/>
    </div>
</div>
<script>
$(document).ready()
{
    $('.progress-bar').hide();
}
$("#image-selector").change(function(){
    let reader = new FileReader();

    reader.onload = function(){
        let dataURL = reader.result;
        $("#selected-image").attr("src",dataURL);
        $("#prediction-list").empty();
    }
    let file = $("#image-selector").prop('files')[0];
    reader.readAsDataURL(file);
});

$("#model-selector").ready(function(){
    loadModel($("#model-selector").val());
    $('.progress-bar').show();
})

let model;
let cutomModelJson;
let cutomModelbin;
async function loadModel(name){
    $("#load").change(async function(){
        for (var i = 0; i < $(this).get(0).files.length; ++i) {
            console.log('AllFiles:::'+JSON.stringify($(this).get(0).files[i]));
            if($(this).get(0).files[i].name == 'my-model-1.json'){
                cutomModelJson = $(this).get(0).files[i];
            }else{
                cutomModelbin = $(this).get(0).files[i];
            }
        }
        console.log('cutomModelJson::'+cutomModelJson.name+'cutomModelbin::'+cutomModelbin.name);
        model = await tf.loadModel(tf.io.browserFiles([cutomModelJson, cutomModelbin]));

        console.log('model'+JSON.stringify(model));
    });
}

$("#predict-button").click(async function(){
    let image= $('#selected-image').get(0);
    console.log('image',image);
    let tensor = preprocessImage(image,$("#model-selector").val());
    const resize_image = tf.reshape(tensor, [1, 224, 224, 3],'resize');
    console.log('tensor',tensor);
    console.log('resize_image',resize_image);
    console.log('model1',model);
    let prediction = await model.predict(tensor).data();
    console.log('prediction:::'+ prediction);

    let top5 = Array.from(prediction)
    .map(function(p,i){
        return {
            probability: p,
            className: prediction[i]
        };
    }).sort(function(a,b){
        return b.probability-a.probability;
    }).slice(0,1);

    $("#prediction-list").empty();
    top5.forEach(function(p){
        $("#prediction-list").append(`<li>${p.className}:${p.probability.toFixed(6)}</li>`);
    });
});

function preprocessImage(image,modelName)
{
    let tensor=tf.browser.fromPixels(image)
    .resizeNearestNeighbor([224,224])
    .toFloat();
    console.log('tensor pro', tensor);
    if(modelName==undefined)
    {
        return tensor.expandDims();
    }
    if(modelName=="mobilenet")
    {
        let offset=tf.scalar(127.5);
        console.log('offset',offset);
        return tensor.sub(offset)
        .div(offset)
        .expandDims();
    }
    else
    {
        throw new Error("UnKnown Model error");
    }
}
</script>

//1st Page which is for Create and train model -
<apex:page sidebar="false">
<head>
    <title>Add image and train model with tensorflowJS</title> 
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.15.1"> </script>
    <script src="https://unpkg.com/@tensorflow-models/mobilenet"></script>
</head>
<div class="container mt-5">
    <div class="row">
        <div class="col-12">
            <div class="progress progress-bar progress-bar-striped progress-bar-animated mb-2">Loading Model</div>
        </div>
    </div>
    <div class="row">
        <div class="col-3">
            <select id="model-selector" class="custom-select" >
                <option>mobilenet</option>                    
            </select>
        </div>
    </div>

    <div class="row">
        <input id ="image-selector" class="form-control border-0" type="file"/>
    </div>

    <div class="col-6">
        <button id="predict-button" class="btn btn-dark float-right">Predict</button>
    </div>
</div>
<div class="row">
    <div class="col">
        <h2>Prediction></h2>
        <ol id="prediction-list"></ol>
    </div>
</div>
<div class="row">
    <div class="col-12">
        <h2 class="ml-3">Image</h2>
        <img id="selected-image" src="{!$Resource.cat}" crossorigin="anonymous" width="400" height="300" />
    </div>
</div>
<script>

$("#model-selector").ready(function(){
    loadModel($("#model-selector").val());
})

let model;
async function loadModel(name){
    model = tf.sequential();
    console.log('model::'+JSON.stringify(model));
}

$("#predict-button").click(async function(){
    let image= $('#selected-image').get(0);
    console.log('image:::',image);
    let tensor = preprocessImage(image,$("#model-selector").val());
    const resize_image = tf.reshape(tensor, [1, 224, 224, 3],'resize');
    console.log('tensorFromImage:::',resize_image);
    // Labels
    const label = ['cat'];
    const setLabel = Array.from(new Set(label));
    const ys = tf.oneHot(tf.tensor1d(label.map((a) => setLabel.findIndex(e => e === a)), 'int32'), 10)
    console.log('ys:::'+ys);

    model.add(tf.layers.conv2d({
        inputShape: [224, 224 , 3],
        kernelSize: 5,
        filters: 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'VarianceScaling'
    }));

    model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
    model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
    model.add(tf.layers.flatten({}));
    model.add(tf.layers.dense({units: 64, activation: 'relu'}));
    model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
    model.compile({
        loss: 'meanSquaredError',
        optimizer : 'sgd'
    })

    // Train the model using the data.
    model.fit(resize_image, ys, {epochs: 100}).then((loss) => {
        const t = model.predict(resize_image);
        console.log('Prediction:::'+t);
        pred = t.argMax(1).dataSync(); // get the class of highest probability
        const labelsPred = Array.from(pred).map(e => setLabel[e])
        console.log('labelsPred:::'+labelsPred);
        const saveResults = model.save('downloads://my-model-1');
        console.log(saveResults);
    }).catch((e) => {
        console.log(e.message);
    })


});


function preprocessImage(image,modelName)
{
    let tensor = tf.browser.fromPixels(image)
    .resizeNearestNeighbor([224,224])
    .toFloat();
    console.log('tensor pro:::', tensor);
    if(modelName==undefined)
    {
        return tensor.expandDims();
    }
    if(modelName=="mobilenet")
    {
        let offset=tf.scalar(127.5);
        console.log('offset:::',offset);
        return tensor.sub(offset)
        .div(offset)
        .expandDims();
    }
    else
    {
        throw new Error("UnKnown Model error");
    }
}
</script>

请告诉我们是否在任何地方出错或需要执行任何其他步骤来实现此任务。

1 个答案:

答案 0 :(得分:0)

保存的模型不包含标签名称。保存的模型包含模型的拓扑和体系结构的权重。

加载保存的模型后,可以预测最后一层的每个给定类的置信度。从这一预测中,一个人只能说出第一类是最有可能还是第二类或第三类,……而一个人却不能说出该类是“ A”,“ B”还是“ C”,.. 。 举些例子。实际上,即使是初始模型也无法做到这一点。这是另一种处理。

通常,在进行分类时,在拟合模型之前会进行一次热编码。因此,该模型不具有标签“ A”,“ B”或“ C”的含义。例如,对于三个类,它仅具有“ 100”,“ 010”,“ 001”的语义。给定一个张量,一个预测可以是[0.1,0.3,0.6],指示输入最有可能属于第三类。

在模型输出后,要获得标签名称,必须进行如下的反向编码

// take the index i of the highest probability
// indexing the element i of the array of labels

在此过程中,模型始终不了解标签名称。因此,此信息不会保存在任何地方。因此,如果将加载的模型用于不同阶段的预测,则应该有一种方法可以在所有这些阶段传递标签名称数组。这条信息可以来自服务器,也可以保存在localStorage上-人们可以考虑多种方式,除了期望模型自身加载即可。