Model loaded with TensorFlow.js performs much worse than the Keras model

Time: 2019-10-13 01:07:31

Tags: machine-learning keras tensorflow.js tensorflowjs-converter

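I'm trying to build a "dogs vs. cats" image classification model with Keras. One of my goals is a website that serves the model with TensorFlow.js. I have successfully deployed the model using Flask as the server.

The main problem is that the model performs far worse in TensorFlow.js than it does in plain Keras. In plain Keras, the model reaches about 90% accuracy on the test data. In TensorFlow.js, however, it does not classify a single test image correctly. Any help or tips for solving this would be much appreciated.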

templates/index.html

<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width">
    <title>repl.it</title>

    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>

    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    <link href="{{ url_for('static', filename='index.css') }}" rel="stylesheet" type="text/css" />
    <link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
  </head>
  <body onload="$('#result').hide();$('#continue').hide();">
    <div class="container-fluid">
      <!-- START HEADER -->
      <div class="row" id="headerRow">
        <div class="col-md d-flex align-items-center" id="headerColumn">
          <h2>Cat<span class='or'>or</span>Dog</h2>
        </div>
      </div>
      <!-- END HEADER -->

      <!-- START BODY -->
      <div class="row bodyRow" id='bodyRow'>
        <div class="col-md d-flex align-items-center bodyColumn">
          <div class="body">
            <form class="d-flex align-items-center  justify-content-center imageSubmitForm" method="POST" enctype="multipart/form-data">
              <label class="d-flex align-items-center justify-content-center" for='imageInputField'>
                <i class="material-icons">file_upload</i>

                <p id='result'></p>
                <br/>
                <p id='continue'>Press Anywhere to continue...</p>
              </label>
              <input class="imageInputField" id='imageInputField' type='file' onchange='getPrediction(url)'/>
            </form>
          </div>
        </div>
      </div>
      <!-- END BODY -->

      <!-- START RESULT -->
      <div class="row resultRow">
        <div class="col-md-6 classResultColumn">
            <div class="d-flex align-items-center justify-content-center classResultBox">
                <p id='classResult'></p>
            </div>
        </div>
        <div class="col-md-6 scoreResultColumn">
            <div class="d-flex align-items-center justify-content-center scoreResultBox">
                <p id='scoreResult'></p>
            </div>
        </div>
      </div>
      <!-- END RESULT -->

      <!-- START FOOTER -->
      <!--
      <div class="row d-flex align-items-center footerRow" id='footerRow'>
        <center><a src="#">Source Code</a></center>
      </div>
      -->
      <!-- FOOTER -->
    </div>

    <!-- START SCRIPTS -->
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
    <script src="{{ url_for('static', filename='index.js') }}"></script>   
    <!-- END SCRIPTS -->
  </body>
</html>

static/index.js

let fileInput = document.getElementById("imageInputField");
let classResultElement = document.getElementById("classResult");
let scoreResultElement = document.getElementById("scoreResult");
let url = "/model";

let model;
let file;
let data;
let responseContent;
let features;
let predictedClass;

let getPrediction = async(url) => {
    if (!model)
        model = await tf.loadLayersModel(url);

    file = fileInput.files[0];
    data = new FormData();
    data.append("file", file);

    $.ajax({
        url : "/api/preprocess",
        type: 'POST',
        data: data,
        traditional: true,
        processData: false,
        contentType: false,

        success: function(response)
        {
            responseContent = JSON.parse(response)['image'];

            if (responseContent != "False")
            {
                features = tf.tensor(responseContent);
                // dataSync() returns a typed array; take the single sigmoid output.
                const score = model.predict(features).dataSync()[0];

                alert(score);

                if (score >= 0.5) {
                    predictedClass = "Dog";

                    classResultElement.innerHTML = "<b>Predicted Class:</b> " + predictedClass;
                    scoreResultElement.innerHTML = "<b>Certainty:</b> " + score*100.0 + "%";
                } else {
                    predictedClass = "Cat";

                    classResultElement.innerHTML = "<b>Predicted Class:</b> " + predictedClass;
                    scoreResultElement.innerHTML = "<b>Certainty:</b> " + (1.0 - score) * 100.0 + "%";
                }

                alert(predictedClass);
            }
        }
    });
}
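The two branches above differ only in the label and in which side of the score is reported, so they could be folded into one small function. The following is a minimal sketch, not part of the original code: `interpretScore` and its `threshold` parameter are hypothetical names, and it assumes (as the listing above does) that the model emits a single sigmoid output where values of 0.5 or more mean "Dog".

```javascript
// Hypothetical helper (not in the original code): map one sigmoid output
// to a class label and a certainty percentage.
// Assumption taken from the listing above: score >= 0.5 means "Dog".
function interpretScore(score, threshold = 0.5) {
  if (!Number.isFinite(score) || score < 0 || score > 1) {
    throw new RangeError("expected a probability in [0, 1], got " + score);
  }
  const isDog = score >= threshold;
  // Report the probability of whichever class was chosen.
  const certainty = (isDog ? score : 1 - score) * 100;
  return {
    predictedClass: isDog ? "Dog" : "Cat",
    certainty: certainty.toFixed(1) + "%",
  };
}
```

Inside the `success` callback, the result could then be written to `classResultElement` and `scoreResultElement` once, instead of once per branch. Rejecting values outside [0, 1] also makes a broken prediction (for example `NaN`) fail loudly rather than silently displaying "Cat".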

app.py

import flask
from flask_cors import CORS
from werkzeug.utils import secure_filename
import time
import os
import keras
import numpy as np
import json
import matplotlib.pyplot as plt

app = flask.Flask(__name__)
CORS(app)

UPLOADS_DIR = "uploads/"

@app.route("/")
def index():
  """
    Fetch and return the main homepage. 
  """
  return flask.render_template("index.html")

@app.route("/favicon.ico")
def get_favicon():
  """
    Return a fake message in order to silence the error caused by a favicon not being found.
  """
  return "Favicon Does Not Exist"

@app.route("/model")
def get_modeljson():
  """
    Get the model.json file and return its contents.
  """
  with open("model/model.json", "r") as f:
    return f.read()

@app.route("/<path:path>")
def get_shard(path):
  """
    get the binary weight file for the model (also known as a shard).

    path    =>    the filename of the binary weight file.
  """
  return flask.send_from_directory("model/", path)

@app.route("/api/preprocess", methods=['POST'])
def preprocess():
  """
    takes an image object from an AJAX request and returns a normalized list of the values.
  """
  if flask.request.method == 'POST':
    file = flask.request.files['file']
    filename = secure_filename(file.filename)
    new_filename = "{}_{}".format(time.time(), filename)
    file.save(os.path.join(UPLOADS_DIR, new_filename))

    img_obj = keras.preprocessing.image.load_img(os.path.join(UPLOADS_DIR, new_filename), target_size=(224, 224))
    img_arr = keras.preprocessing.image.img_to_array(img_obj).reshape(1, 224, 224, 3)
    img_arr = np.divide(img_arr, 255.)

    os.remove(os.path.join(UPLOADS_DIR, new_filename))
    return json.dumps({"image":img_arr.tolist()})
  return json.dumps({"image":"False"})

if __name__ == "__main__":
  app.run()

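You can find the URL of the Kaggle notebook used to train the model here. You can find the notebook used to test the code here.

Any help or tips would be greatly appreciated.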

1 answer:

Answer 0 (score: 0)

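After a ton of coffee and almost no sleep, I stumbled on a solution. Apparently the internals of the WebGL backend differ from the internals of TensorFlow in Python. The workaround that fixed it for me was to turn off the WebGL backend's packing optimization.

Before loading the model, add: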
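One way to check whether the backend really is the source of the discrepancy is to feed the same preprocessed image to both Keras and TensorFlow.js and compare the raw outputs numerically. The helper below is a rough sketch for that comparison; `maxAbsDiff` is a hypothetical name, and the reference values are assumed to have been copied over from a Keras `model.predict` call by hand.

```javascript
// Hypothetical diagnostic helper: largest element-wise absolute difference
// between a reference output (e.g. from Keras) and the TF.js output.
// Works on plain arrays and typed arrays such as the result of dataSync().
function maxAbsDiff(expected, actual) {
  if (expected.length !== actual.length) {
    throw new Error(
      "length mismatch: " + expected.length + " vs " + actual.length
    );
  }
  let worst = 0;
  for (let i = 0; i < expected.length; i++) {
    worst = Math.max(worst, Math.abs(expected[i] - actual[i]));
  }
  return worst;
}

// Possible use in the browser (assumes `model` and `features` from index.js):
//   const diff = maxAbsDiff(kerasScores, model.predict(features).dataSync());
//   console.log("max abs difference vs. Keras:", diff);
```

A difference around floating-point noise would point away from the backend and toward preprocessing or label handling, while a large one supports the explanation given here. Repeating the comparison after `await tf.setBackend('cpu')` is another cross-check, since that path does not go through WebGL at all.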

tf.ENV.set("WEBGL_PACK", false);

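As far as I can tell, this does not switch WebGL off entirely; it turns off the backend's packed operations, and that was enough to make TFJS behave much more like Python for me!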
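The page in the question loads `@tensorflow/tfjs@latest`, so the exact way to reach the environment object may depend on the release that gets served. The sketch below sets the same flag through the `tf.env()` accessor when it is available and falls back to the `tf.ENV` object used above otherwise. `disableWebglPacking` is a hypothetical name, and which releases expose which form is an assumption worth checking against the version actually loaded.

```javascript
// Hypothetical wrapper around the flag from the answer above.
// Assumption: the library exposes either a tf.env() accessor or a tf.ENV
// object, and both offer set(flagName, value).
function disableWebglPacking(tf) {
  const env = typeof tf.env === "function" ? tf.env() : tf.ENV;
  if (!env || typeof env.set !== "function") {
    throw new Error("no TensorFlow.js environment object found");
  }
  env.set("WEBGL_PACK", false);
  return env;
}

// Possible use in index.js, before the model is loaded:
//   disableWebglPacking(tf);
//   model = await tf.loadLayersModel(url);
```

Passing `tf` in as an argument keeps the function free of globals, which also makes it easy to exercise with a stub outside the browser.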