分类器输出错误

时间:2019-03-07 10:28:55

标签: tensorflow.js

我是机器学习的新手,我使用了mnist演示模型来训练猫和狗的分类器。但这似乎效果不佳。以下是该模型的一些示意图:

onEpochEnd onBatchEnd perClassAccuracy

info

该模型似乎总是将任何输入预测为猫。 这是我的代码。请帮助我。

index.js:

import {IMAGE_H, IMAGE_W, MnistData} from './data.js';


import * as ui from './ui.js';


let classNum = 0;
function createConvModel() {

    const model = tf.sequential();
    model.add(tf.layers.conv2d({
        inputShape: [IMAGE_H, IMAGE_W, 3],
        kernelSize: 5,
        filters: 32,
        activation: 'relu'
    }));

    model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));

    model.add(tf.layers.conv2d({kernelSize: 5, filters: 32, activation: 'relu'}));

    model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));

    model.add(tf.layers.conv2d({kernelSize: 5, filters: 64, activation: 'relu'}));

    model.add(tf.layers.flatten({}));

    model.add(tf.layers.dense({units: 64, activation: 'relu'}));

    model.add(tf.layers.dense({units: classNum, activation: 'softmax'}));

    return model;
}


function createDenseModel() {
    const model = tf.sequential();
    model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 3]}));
    model.add(tf.layers.dense({units: 42, activation: 'relu'}));
    model.add(tf.layers.dense({units: classNum, activation: 'softmax'}));
    return model;
}

async function train(model, fitCallbacks) {
    ui.logStatus('Training model...');

    const optimizer = 'rmsprop';

    model.compile({
        optimizer,
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy'],
    });

    const batchSize = 64;

    const trainEpochs = ui.getTrainEpochs();

    let trainBatchCount = 0;

    const trainData = data.getTrainData();
    const valData = data.getValData();
    const testData = data.getTestData();


    await model.fit(trainData.xs, trainData.labels, {
        batchSize:batchSize,
        validationData:[valData.xs,valData.labels],
        shuffle:true,
        epochs: trainEpochs,
        callbacks: fitCallbacks
    });
    console.log("complete");
    const classNames = ['cat','dog'];
    const [preds, labels] = doPrediction(model,testData);
    const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
    const container = { name: 'Accuracy', tab: 'Evaluation' };
    tfvis.show.perClassAccuracy(container, classAccuracy, classNames);

}

function doPrediction(model,testData) {
    const testxs = testData.xs;
    const labels = testData.labels.argMax([-1]);
    const preds = model.predict(testxs).argMax([-1]);

    testxs.dispose();
    return [preds, labels];
}

function createModel() {
    let model;
    const modelType = ui.getModelTypeId();
    if (modelType === 'ConvNet') {
        model = createConvModel();
    } else if (modelType === 'DenseNet') {
        model = createDenseModel();
    } else {
        throw new Error(`Invalid model type: ${modelType}`);
    }
    return model;
}

async function watchTraining(model) {
    const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
    const container = {
        name: 'charts', tab: 'Training', styles: { height: '1000px' }
    };
    const callbacks = tfvis.show.fitCallbacks(container, metrics);
    return train(model, callbacks);
}

let data;
async function load() {
    tf.disableDeprecationWarnings();
    classNum = await localforage.getItem('classNum');
    tfvis.visor();
    data = new MnistData();
    await data.load();
}


ui.setTrainButtonCallback(async () => {
    ui.logStatus('Loading data...');
    await load();

    ui.logStatus('Creating model...');
    const model = createModel();
    model.summary();

    ui.logStatus('Starting model training...');

    await watchTraining(model);
});

data.js:

export const IMAGE_H = 64;
export const IMAGE_W = 64;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;
let NUM_CLASSES = 0;
let trainImagesLabels;
let testLabels;
let trainImages ;
let testImages ;
let validateImages;
let validateLabels;
let validateSplit = 0.2;
let modelId;
let classNum;

/**
 * A class that fetches the sprited MNIST dataset and provide data as
 * tf.Tensors.
 */
export class MnistData {
  constructor() {}

    //shuffle
    static shuffleSwap(arr1,arr2) {
        if(arr1.length == 1) return {arr1,arr2};
        let i = arr1.length;
        while(--i > 1) {
            let j = Math.floor(Math.random() * (i+1));
            [arr1[i], arr1[j]] = [arr1[j], arr1[i]];
            [arr2[i], arr2[j]] = [arr2[j], arr2[i]];
        }
        return {arr1,arr2};
    }

  async load() {
    //get data from localforage
    this.trainImages = await localforage.getItem('dataset');
    this.trainImagesLabels = await localforage.getItem('datasetLabel');
    this.modelId = await localforage.getItem('modelId');
    this.classNum = await localforage.getItem('classNum');

      this.trainImages.shift();
      this.trainImagesLabels.shift();

      //construct the validateData
      let status = false;
      let maxVal = Math.floor(this.trainImages.length * 0.2);

      this.validateImages = new Array();
      this.validateLabels = new Array();
    for(let i=0;i<maxVal;i++){
      if(status){
          this.validateImages.push(this.trainImages.pop());
          this.validateLabels.push(this.trainImagesLabels.pop());
          status = false;
      }else{
          this.validateImages.push(this.trainImages.shift());
          this.validateLabels.push(this.trainImagesLabels.shift());
          status = true;
      }
    }
    //construct the testData
      this.testImages = new Array();
      this.testLabels = new Array();
      for(let i=0;i<maxVal;i++){
          if(status){
              this.testImages.push(this.trainImages.pop());
              this.testLabels.push(this.trainImagesLabels.pop());
              status = false;
          }else{
              this.testImages.push(this.trainImages.shift());
              this.testLabels.push(this.trainImagesLabels.shift());
              status = true;
          }
      }
    //shuffle
      let val = MnistData.shuffleSwap(this.validateImages,this.validateLabels);
      this.validateImages = val.arr1;
      this.validateLabels = val.arr2;
      let train = MnistData.shuffleSwap(this.trainImages,this.trainImagesLabels);
      this.trainImages = train.arr1;
      this.trainImagesLabels = train.arr2;
  }



  getTrainData() {
    const xs = tf.tensor4d(this.trainImages);
    const labels = tf.oneHot(tf.tensor1d(this.trainImagesLabels,'int32'),this.classNum);
    return {xs, labels};
  }



    getValData() {
        const xs = tf.tensor4d(this.validateImages);
        const labels = tf.oneHot(tf.tensor1d(this.validateLabels,'int32'),this.classNum);
        return {xs, labels};
    }

    getTestData() {
        const xs = tf.tensor4d(this.testImages);
        const labels = tf.oneHot(tf.tensor1d(this.testLabels,'int32'),this.classNum);
        return {xs, labels};
    }
}
我在开始时添加了一些图片。

  
  //getclassNum
  function getClassNum(files) {
        let classArr = new Array();
        let dirArr = new Array();
        let imageNum = 0;
        for (let i = 0; i < files.length; i++) {
            if (files[i].type.split('/')[0] == 'image' && files[i].type.split('/')[1] == 'jpeg') {
                dirArr = files[i].webkitRelativePath.split('/');
                let currentClassIndex = dirArr.length - 2;
                let isExist = false;
                if (currentClassIndex <= 0)
                    isExist = true;
                else {
                    imageNum++;
                }
                if (classArr == null) {
                    classArr.push(dirArr[currentClassIndex]);
                }
                for (let j = 0; j < classArr.length; j++) {
                    if (classArr[j] == dirArr[currentClassIndex]) {
                        isExist = true;
                    }
                }
                if (!isExist) {
                    classArr.push(dirArr[currentClassIndex]);
                }
            }
        }
        let classNum = classArr.length;
        return {classNum, imageNum, classArr};
    }
  //get nested array
  function getDataset(files, classArr,imgNum) {
        let trainLabelArr = new Array();
        let trainDataArr = new Array();
        for (let i = 0; i < files.length; i++) {
            if (files[i].type.split('/')[0] == 'image'&& files[i].type.split('/')[1] == 'jpeg') {
                let dirArr = files[i].webkitRelativePath.split('/');
                let currentClassIndex = dirArr.length - 2;
                if (currentClassIndex >= 0) {
                    for(let j=0;j<classArr.length;j++){
                        if(dirArr[currentClassIndex]==classArr[j]){
                            let reader = new FileReader();
                            reader.readAsDataURL(files[i]);
                            reader.onload = function () {
                                document.getElementById('image').setAttribute( 'src', reader.result);
                                let tensor= tf.browser.fromPixels(document.getElementById('image'));
                                let nest =  tensor.arraySync();

                                trainDataArr.push(nest);
                                trainLabelArr.push(j);
                            }
                        }
                    }
                }
            }
        }
        return{trainDataArr,trainLabelArr,trainDataLength}
    }
  //getfiles
  async function fileChange(that) {
        let files = that.files;
        let container = getClassNum(files);
          
        let data = getDataset(files, container.classArr,container.imageNum);
        let trainDataArr = data.trainDataArr;
        let trainLabelArr = data.trainLabelArr;

        setTimeout(function () {
  
            localforage.setItem('dataset',trainDataArr,function (err,result) {
                 
            });
             localforage.setItem('datasetLabel',trainLabelArr,function (err,result) {

            });
             localforage.setItem('modelId',modelId,function (err,result) {

            });
             localforage.setItem('classNum',container.classNum,function (err,result) {

            });
        },container.imageNum * 10);

        }
   }

1 个答案:

答案 0 :(得分:0)

让我回答我的问题。经过一天的测试,我发现此模型需要大量数据。每个类别至少需要1,000张图片。如果训练数据不足,则模型只能输出一个结果。此外,该模型在识别带有较少字符(例如字母和符号)的对象时表现非常好,而在识别动物或自然环境方面却表现不佳。