在tensorflow.js中分类

时间:2019-04-22 23:32:21

标签: tensorflow tensorflow.js

this tutorial之后,我想在tensorflowjs中加载并使用一个模型,然后使用classify方法对输入进行分类。

我这样加载并执行模型:

const model = await window.tf.loadGraphModel(MODEL_URL);

const threshold = 0.9;
const labelsToInclude = ["test1"];

model.load(threshold, labelsToInclude).then(model2 => {
    model2.classify(["test sentence"])
      .then(predictions => {
    console.log('prediction: ' + predictions);
    return true;
  })
});

但是我得到了错误:

  

TypeError:model2.classify在App.js:23中不是函数

如何在tensorflowjs中正确使用分类方法?

1 个答案:

答案 0 :(得分:1)

本教程使用特定的模型(toxicity)。它的loadclassify函数不是Tensorflow.js模型本身的功能,而是由此特定模型实现的。

查看API,以大致了解模型支持的功能。如果加载GraphModel,则要使用model.predict(或execute)函数来执行模型。

因此,您的代码应如下所示:

const model = await window.tf.loadGraphModel(MODEL_URL);
const input = tf.tensor(/* ... */); // whatever a valid tensor looks like for your model
const predictions = model.predict([input]);
console.log('prediction: ' + predictions);