使用 TensorFlow.js 进行迁移学习时出现大小/形状错误

时间:2019-11-19 10:49:02

标签: javascript python tensorflow machine-learning tensorflow.js

我试图在 TensorFlow.js 中结合 knnClassifier 和 MobileNet 图像识别模型来做迁移学习,但是收到了以下错误:

Size(28672) must match the product of shape 28,3072(大小 28672 必须与形状 [28, 3072] 各维的乘积相等)

我不知道该如何解决这个问题。我尝试过创建 tensor3d,也尝试过用双线性插值和最近邻插值调整图像大小,但都无济于事。希望有人能帮忙看一下。

请注意,我的思路是用若干文件夹中的图像进行训练,并通过 knnClassifier 的 addExample 方法把每张图像归入其对应的类别。我有一个根据路径读取图像的函数,以及一个用于训练模型并对图像做出预测的异步函数。

---

const tf = require('@tensorflow/tfjs');
//MobileNet : pre-trained model for TensorFlow.js
const mobilenet = require('@tensorflow-models/mobilenet');
//The module provides native TensorFlow execution
//in backend JavaScript applications under the Node.js runtime.
const tfnode = require('@tensorflow/tfjs-node');

const knnClassifier = require('./node_modules/@tensorflow-models/knn-classifier/dist/knn-classifier');

var glob = require('glob')
//The fs module provides an API for interacting with the file system.
const fs = require('fs');

/**
 * Loads an image file from disk and turns it into a float32 tensor.
 *
 * @param {string} path - Location of the image file (BMP, GIF, JPEG or PNG).
 * @returns {tf.Tensor} The decoded image, resized to 32x32 with
 *   nearest-neighbour sampling and cast to float32. The caller owns the
 *   returned tensor and should dispose it when done.
 */
const readImage = path => {
  // readFileSync is synchronous and blocks until the whole file is read.
  const imageBuffer = fs.readFileSync(path);

  // tf.tidy releases the intermediate tensors (decoded and resized image)
  // when the callback returns; only the returned tensor stays alive.
  return tf.tidy(() => {
    // Request 3 channels so grayscale images or PNGs with an alpha channel
    // still decode to RGB data.
    const decoded = tfnode.node.decodeImage(imageBuffer, 3);
    const small = tf.image.resizeNearestNeighbor(decoded, [32, 32]);
    return tf.cast(small, 'float32');
  });
};


// Root folder of the training samples; each sub-folder is treated as one class.
const mainDirectory = "./img_samples/";

/**
 * Classifies one image with MobileNet, then trains a KNN classifier on
 * MobileNet activations of the sample images under `mainDirectory` and uses
 * it to predict the class of './hospitalTest.jpg'.
 *
 * Every sub-folder of `mainDirectory` is one class; the class label is the
 * folder's index in the directory listing.
 *
 * @param {string} path - Image file to run through the plain MobileNet classifier.
 * @returns {Promise<void>} Resolves once both results have been logged.
 */
const imageClassification = async path => {
  const classifier = await knnClassifier.create();

  const image = readImage(path);
  // Load the model.
  const model = await mobilenet.load();
  // Classify the image.
  const predictions = await model.classify(image);
  image.dispose();
  // print results on terminal
  console.log('Classification Results:', predictions);

  // One entry per class folder: the full paths of all files inside it.
  const filesPerClass = fs.readdirSync(mainDirectory).map(folder =>
    fs.readdirSync(mainDirectory + folder).map(file => `${mainDirectory}${folder}/${file}`)
  );

  for (const [classIndex, classFiles] of filesPerClass.entries()) {
    for (const file of classFiles) {
      const imageSample = readImage(file);
      console.log(imageSample);
      // The classifier is trained on MobileNet activations, not raw pixels.
      const activation = await model.infer(imageSample, 'conv_preds');
      classifier.addExample(activation, classIndex);
      // Free the per-sample tensors so memory does not grow with the
      // number of training images.
      imageSample.dispose();
      activation.dispose();
    }
  }

  const testImage = readImage('./hospitalTest.jpg');
  console.log(testImage);
  // The query must go through the same MobileNet embedding as the training
  // examples. Passing the raw image (32*32*3 = 3072 values) instead of the
  // activation is what produced the "Size(...) must match the product of
  // shape ..." error.
  const testActivation = await model.infer(testImage, 'conv_preds');
  const predictionsTest = await classifier.predictClass(testActivation);
  testImage.dispose();
  testActivation.dispose();
  console.log('classficationTest:', predictionsTest);
};

// Command-line entry point: exactly one argument (the image file) is expected.
if (process.argv.length !== 3) throw new Error('Incorrect arguments: node classify.js <IMAGE_FILE>');

// imageClassification returns a promise; handle a rejection explicitly so a
// failure is reported and the process exits with a non-zero status.
imageClassification(process.argv[2]).catch(err => {
  console.error(err);
  process.exitCode = 1;
});

1 个答案:

答案 0 :(得分:0)

由于 KNN 分类器是用 MobileNet 的输出(激活值)训练的,预测时也必须采用同样的方式:先把图像输入 MobileNet,再把它的输出交给分类器。

// Run the test image through MobileNet first so the KNN classifier receives
// the same kind of activation it was trained on, then ask it for the class.
// (Meant to run inside the async function where `model` and `classifier` exist.)
const outputMobilenet = await model.infer(readImage('./hospitalTest.jpg'), 'conv_preds');
const predicted = await classifier.predictClass(outputMobilenet);