Tensorflow.js预测的上限似乎为1?

时间:2018-08-06 20:16:05

标签: javascript tensorflow neural-network tensorflow.js

我正在玩一些从youtube教程中获得的,可以预测花朵数据的tensorflow代码。这是脚本(训练数据分配给变量“ iris”,测试数据分配给变量“ irisTesting”:

const trainingData = tf.tensor2d(iris.map(item => [
    item.sepal_length, item.petal_length, item.petal_width,
]));
const outputData = tf.tensor2d(iris.map(item => [
    item.species === "setosa" ? 1 : 0,
    item.species === "virginica" ? 1 : 0,
    item.species === "versicolor" ? 1 : 0,
    item.sepal_width
]));
const testingData = tf.tensor2d(irisTesting.map(item => [
    item.sepal_length, item.petal_length, item.petal_width
]));

const model = tf.sequential();

model.add(tf.layers.dense({
    inputShape: [3],
    activation: "sigmoid",
    units: 5,
}));
model.add(tf.layers.dense({
    inputShape: [5],
    activation: "sigmoid",
    units: 4,
}));
model.add(tf.layers.dense({
    activation: "sigmoid",
    units: 4,
}));
model.compile({
    loss: "meanSquaredError",
    optimizer: tf.train.adam(.06),
});
const startTime = Date.now();
model.fit(trainingData, outputData, {epochs: 100})
    .then((history) => {
         //console.log(history);
        console.log("Done training in " + (Date.now()-startTime) / 1000 + " seconds.");
        model.predict(testingData).print();
    });

当控制台打印预测的sepal_width时,它似乎有一个上限1。训练数据的sepal_width值远远超过1,但是这里记录的数据是: / p>

Tensor
    [[0.9561102, 0.0028415, 0.0708825, 0.9997129],
     [0.0081552, 0.9410981, 0.0867947, 0.999761 ],
     [0.0346453, 0.1170913, 0.8383155, 0.9999373]]

最后(第四列)将是预测的sepal_width值。预测值应大于1,但是似乎有某种阻止它大于1的情况。

这是原始代码: https://gist.github.com/learncodeacademy/a96d80a29538c7625652493c2407b6be

3 个答案:

答案 0 :(得分:1)

最后一层的激活功能为sigmoid

Sigmoid函数如下所示: Sigmoid

并且您可以看到它限制在0到1的范围内。因此,如果需要其他输出值,则需要相应地调整上一个激活功能。

答案 1 :(得分:1)

您将在最后一层使用S型激活函数来预测sepal_width。 Sigmoid是连续函数,范围在0到1之间。有关更详尽的说明,请参见Wikipedia

如果要预测sepal_width,应尝试使用其他激活功能。有关可用激活功能的列表,您可以检查Tensorflow's API page(这是针对Python版本的,但是对于JavaScript版本应该与此类似)。您可以尝试'softplus''relu'甚至'linear',但我不能说其中任何一个是否适合您的应用程序。尝试并尝试一下,看看哪种方法最好。

答案 2 :(得分:1)

here中的原始代码解决了分类问题。在您的item.sepal_width中添加outputData是没有意义的,因为它不是另一个类。