如果将代码放入函数中,代码将停止工作,为什么?

时间:2019-06-14 19:03:41

标签: p5.js tensorflow.js

我收到错误消息:

  

未捕获(承诺)TypeError:model.predict不是函数

但是useModel函数中的代码,如果我将其移到Im调用useModel函数的地方,它将起作用。我不明白为什么。但这对我没有帮助,因为我将需要能够在自己的函数中在设置函数之外进行预测。

我认为它与promise有关,并且我确实尝试将async放在useModel函数前面。但是不确定为什么会有帮助。

也许以某种巧妙的方式使用.then?

let data;
let xs;
let ys;

function preload(){
  data = loadJSON('gridson.json');
}

function setup() {
  createCanvas(40, 40);


  // prepare data for tensor
  let board = [];
  for (let i =0; i < data.in.length; i++){
    let norm = [];
    for (let j =0; j < 200; j++){
      norm.push(data['in'][i]['arr'][j] / 2);
    }
    board.push(norm);
  }

  xs = tf.tensor2d(board);


  let labelList = ['left', 'right', 'rotate', 'fall'];
  let label = [];
  for (let record of data.in){
    label.push(labelList.indexOf(record.move));
  }

  let labelTensor = tf.tensor1d(label, 'int32');

  ys = tf.oneHot(labelTensor, 4).cast('float32');
  labelTensor.dispose();


  // create the model
  let model = tf.sequential();
  let hidden = tf.layers.dense({
    units: 16,
    inputShape: [200],
    activation: 'sigmoid'
  });
  let output = tf.layers.dense({
    units: 4,
    activation: 'softmax'
  });
  model.add(hidden);
  model.add(output);


  // create an optimizer
  const lr = 0.1;
  const optimizer = tf.train.sgd(lr);

  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });


  // train model
  model.fit(xs, ys, {
    shuffle: true,
    validationSplit: 0.1,
    epochs: 1,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        console.log(epoch);
      },
      onBatchEnd: async (batch, logs) => {
        await tf.nextFrame();
      },
      onTrainEnd: () => {
        console.log('finished');

        // use the model
        useModel();

      },
    },
  });

}

function useModel(){
  tf.tidy(() => {
    let grid = [];
    for (let h =0; h < 200; h++){
      grid.push(0); // create junk test data
    }
    const input = tf.tensor2d([grid]);
    let results = model.predict(input);
    let argMax = results.argMax(1);
    let index = argMax.dataSync()[0];
    let label = labelList[index];
    console.log(label);
  });
}

function draw() {
  background(150);
}

1 个答案:

答案 0 :(得分:1)

我只需要全局声明模型(和labelList)变量,它似乎可以工作。