我是机器学习的新手,我使用了mnist演示模型来训练猫和狗的分类器。但这似乎效果不佳。以下是该模型的一些示意图:>
该模型似乎总是将任何输入预测为猫。 这是我的代码。请帮助我。
index.js:
import {IMAGE_H, IMAGE_W, MnistData} from './data.js';
import * as ui from './ui.js';
let classNum = 0;
function createConvModel() {
const model = tf.sequential();
model.add(tf.layers.conv2d({
inputShape: [IMAGE_H, IMAGE_W, 3],
kernelSize: 5,
filters: 32,
activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
model.add(tf.layers.conv2d({kernelSize: 5, filters: 32, activation: 'relu'}));
model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
model.add(tf.layers.conv2d({kernelSize: 5, filters: 64, activation: 'relu'}));
model.add(tf.layers.flatten({}));
model.add(tf.layers.dense({units: 64, activation: 'relu'}));
model.add(tf.layers.dense({units: classNum, activation: 'softmax'}));
return model;
}
function createDenseModel() {
const model = tf.sequential();
model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 3]}));
model.add(tf.layers.dense({units: 42, activation: 'relu'}));
model.add(tf.layers.dense({units: classNum, activation: 'softmax'}));
return model;
}
async function train(model, fitCallbacks) {
ui.logStatus('Training model...');
const optimizer = 'rmsprop';
model.compile({
optimizer,
loss: 'categoricalCrossentropy',
metrics: ['accuracy'],
});
const batchSize = 64;
const trainEpochs = ui.getTrainEpochs();
let trainBatchCount = 0;
const trainData = data.getTrainData();
const valData = data.getValData();
const testData = data.getTestData();
await model.fit(trainData.xs, trainData.labels, {
batchSize:batchSize,
validationData:[valData.xs,valData.labels],
shuffle:true,
epochs: trainEpochs,
callbacks: fitCallbacks
});
console.log("complete");
const classNames = ['cat','dog'];
const [preds, labels] = doPrediction(model,testData);
const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
const container = { name: 'Accuracy', tab: 'Evaluation' };
tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
}
function doPrediction(model,testData) {
const testxs = testData.xs;
const labels = testData.labels.argMax([-1]);
const preds = model.predict(testxs).argMax([-1]);
testxs.dispose();
return [preds, labels];
}
function createModel() {
let model;
const modelType = ui.getModelTypeId();
if (modelType === 'ConvNet') {
model = createConvModel();
} else if (modelType === 'DenseNet') {
model = createDenseModel();
} else {
throw new Error(`Invalid model type: ${modelType}`);
}
return model;
}
async function watchTraining(model) {
const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
const container = {
name: 'charts', tab: 'Training', styles: { height: '1000px' }
};
const callbacks = tfvis.show.fitCallbacks(container, metrics);
return train(model, callbacks);
}
let data;
async function load() {
tf.disableDeprecationWarnings();
classNum = await localforage.getItem('classNum');
tfvis.visor();
data = new MnistData();
await data.load();
}
ui.setTrainButtonCallback(async () => {
ui.logStatus('Loading data...');
await load();
ui.logStatus('Creating model...');
const model = createModel();
model.summary();
ui.logStatus('Starting model training...');
await watchTraining(model);
});
data.js:
export const IMAGE_H = 64;
export const IMAGE_W = 64;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;
let NUM_CLASSES = 0;
let trainImagesLabels;
let testLabels;
let trainImages ;
let testImages ;
let validateImages;
let validateLabels;
let validateSplit = 0.2;
let modelId;
let classNum;
/**
* A class that fetches the sprited MNIST dataset and provide data as
* tf.Tensors.
*/
export class MnistData {
constructor() {}
//shuffle
static shuffleSwap(arr1,arr2) {
if(arr1.length == 1) return {arr1,arr2};
let i = arr1.length;
while(--i > 1) {
let j = Math.floor(Math.random() * (i+1));
[arr1[i], arr1[j]] = [arr1[j], arr1[i]];
[arr2[i], arr2[j]] = [arr2[j], arr2[i]];
}
return {arr1,arr2};
}
async load() {
//get data from localforage
this.trainImages = await localforage.getItem('dataset');
this.trainImagesLabels = await localforage.getItem('datasetLabel');
this.modelId = await localforage.getItem('modelId');
this.classNum = await localforage.getItem('classNum');
this.trainImages.shift();
this.trainImagesLabels.shift();
//construct the validateData
let status = false;
let maxVal = Math.floor(this.trainImages.length * 0.2);
this.validateImages = new Array();
this.validateLabels = new Array();
for(let i=0;i<maxVal;i++){
if(status){
this.validateImages.push(this.trainImages.pop());
this.validateLabels.push(this.trainImagesLabels.pop());
status = false;
}else{
this.validateImages.push(this.trainImages.shift());
this.validateLabels.push(this.trainImagesLabels.shift());
status = true;
}
}
//construct the testData
this.testImages = new Array();
this.testLabels = new Array();
for(let i=0;i<maxVal;i++){
if(status){
this.testImages.push(this.trainImages.pop());
this.testLabels.push(this.trainImagesLabels.pop());
status = false;
}else{
this.testImages.push(this.trainImages.shift());
this.testLabels.push(this.trainImagesLabels.shift());
status = true;
}
}
//shuffle
let val = MnistData.shuffleSwap(this.validateImages,this.validateLabels);
this.validateImages = val.arr1;
this.validateLabels = val.arr2;
let train = MnistData.shuffleSwap(this.trainImages,this.trainImagesLabels);
this.trainImages = train.arr1;
this.trainImagesLabels = train.arr2;
}
getTrainData() {
const xs = tf.tensor4d(this.trainImages);
const labels = tf.oneHot(tf.tensor1d(this.trainImagesLabels,'int32'),this.classNum);
return {xs, labels};
}
getValData() {
const xs = tf.tensor4d(this.validateImages);
const labels = tf.oneHot(tf.tensor1d(this.validateLabels,'int32'),this.classNum);
return {xs, labels};
}
getTestData() {
const xs = tf.tensor4d(this.testImages);
const labels = tf.oneHot(tf.tensor1d(this.testLabels,'int32'),this.classNum);
return {xs, labels};
}
}
//getclassNum
function getClassNum(files) {
let classArr = new Array();
let dirArr = new Array();
let imageNum = 0;
for (let i = 0; i < files.length; i++) {
if (files[i].type.split('/')[0] == 'image' && files[i].type.split('/')[1] == 'jpeg') {
dirArr = files[i].webkitRelativePath.split('/');
let currentClassIndex = dirArr.length - 2;
let isExist = false;
if (currentClassIndex <= 0)
isExist = true;
else {
imageNum++;
}
if (classArr == null) {
classArr.push(dirArr[currentClassIndex]);
}
for (let j = 0; j < classArr.length; j++) {
if (classArr[j] == dirArr[currentClassIndex]) {
isExist = true;
}
}
if (!isExist) {
classArr.push(dirArr[currentClassIndex]);
}
}
}
let classNum = classArr.length;
return {classNum, imageNum, classArr};
}
//get nested array
function getDataset(files, classArr,imgNum) {
let trainLabelArr = new Array();
let trainDataArr = new Array();
for (let i = 0; i < files.length; i++) {
if (files[i].type.split('/')[0] == 'image'&& files[i].type.split('/')[1] == 'jpeg') {
let dirArr = files[i].webkitRelativePath.split('/');
let currentClassIndex = dirArr.length - 2;
if (currentClassIndex >= 0) {
for(let j=0;j<classArr.length;j++){
if(dirArr[currentClassIndex]==classArr[j]){
let reader = new FileReader();
reader.readAsDataURL(files[i]);
reader.onload = function () {
document.getElementById('image').setAttribute( 'src', reader.result);
let tensor= tf.browser.fromPixels(document.getElementById('image'));
let nest = tensor.arraySync();
trainDataArr.push(nest);
trainLabelArr.push(j);
}
}
}
}
}
}
return{trainDataArr,trainLabelArr,trainDataLength}
}
//getfiles
async function fileChange(that) {
let files = that.files;
let container = getClassNum(files);
let data = getDataset(files, container.classArr,container.imageNum);
let trainDataArr = data.trainDataArr;
let trainLabelArr = data.trainLabelArr;
setTimeout(function () {
localforage.setItem('dataset',trainDataArr,function (err,result) {
});
localforage.setItem('datasetLabel',trainLabelArr,function (err,result) {
});
localforage.setItem('modelId',modelId,function (err,result) {
});
localforage.setItem('classNum',container.classNum,function (err,result) {
});
},container.imageNum * 10);
}
}
答案 0 :(得分:0)
让我回答我的问题。经过一天的测试,我发现此模型需要大量数据。每个类别至少需要1,000张图片。如果训练数据不足,则模型只能输出一个结果。此外,该模型在识别带有较少字符(例如字母和符号)的对象时表现非常好,而在识别动物或自然环境方面却表现不佳。