如何使用TensorFlow.js进行多变量回归

时间:2019-04-02 22:35:37

标签: javascript math regression mathematical-optimization tensorflow.js

我想使用TensorFlow拟合非线性多变量方程。公式如下。适合的参数是a0,a1和a2。自变量是S和R,而F是因变量。下面的代码中分别提供了S,R,F的相应数据,分别为Sdata,Rdata和Fdata。

  

F = a0 + a1 * S + a2 * R

const Sdata = tf.tensor1d([13.8,13.8,20.2,12.1,14.1,29.4,13.7,16.6,18.9,15.5]);
const Fdata = tf.tensor1d([46.7,130.7,78.1,72.2,40.1,78.6,57.4,170.7,80.2,45.2]);
const Rdata = tf.tensor1d([1.5,4.5,2.5,3.0,3.5,3.0,2.5,3.0,3.0,2.5])

const a0 = tf.scalar(Math.random()).variable();
const a1 = tf.scalar(Math.random()).variable();
const a2 = tf.scalar(Math.random()).variable();

const fun = (r,s) => a2.mul(r).add(a1.mul(s)).add(a0)
const cost = (pred, label) => pred.sub(label).square().mean();

const learningRate = 0.01;
const optimizer = tf.train.sgd(learningRate);

// Train the model.
for (let i = 0; i < 800; i++) {
    optimizer.minimize(() => cost(fun(Rdata,Sdata), Fdata));
}

如我的代码所示,我假设函数“ fun”可以带有两个自变量。我得到的是NaN,而不是a0 = -6.6986,a1 = 0.8005和a2 = 25.2523。

这是否意味着无法在tensorflow.js中拟合多变量函数?我不这么认为。我将不胜感激。

1 个答案:

答案 0 :(得分:0)

由于学习率高,该模型一直在振荡以找到最佳参数。实际上,参数不断增加到Infinity。

调整学习速率将使模型能够找到最佳参数。在这种情况下,0.001似乎可以提供良好的结果。如果要提高模型的准确性,可以考虑将所有输入数据规格化为相同的数量级-0到1之间

const Sdata = tf.tensor1d([13.8,13.8,20.2,12.1,14.1,29.4,13.7,16.6,18.9,15.5]);
const Fdata = tf.tensor1d([46.7,130.7,78.1,72.2,40.1,78.6,57.4,170.7,80.2,45.2]);
const Rdata = tf.tensor1d([1.5,4.5,2.5,3.0,3.5,3.0,2.5,3.0,3.0,2.5])

const a0 = tf.scalar(Math.random()).variable();
const a1 = tf.scalar(Math.random()).variable();
const a2 = tf.scalar(Math.random()).variable();

const fun = (r,s) => a2.mul(r).add(a1.mul(s)).add(a0)
const cost = (pred, label) => pred.sub(label).square().mean();

const learningRate = 0.001;
const optimizer = tf.train.sgd(learningRate);

// Train the model.
for (let i = 0; i < 800; i++) {
    console.log("training")
    optimizer.minimize(() => cost(fun(Rdata,Sdata), Fdata));
}

console.log(`a: ${a0.dataSync()}, b: ${a1.dataSync()}, c: ${a2.dataSync()}`);

const preds = fun(Rdata,Sdata).dataSync();
preds.forEach((pred, i) => {
   console.log(`x: ${i}, pred: ${pred}`);
});
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
  </head>

  <body>
  </body>
</html>