TensorFlow.js中用于颜色预测的最佳模型类型?

时间:2018-11-21 16:14:04

标签: javascript tensorflow machine-learning tensorflow.js

当我意识到出现问题时,我正在创建色彩预测器。我让模型成功运行,但是预测始终在大约2.5到5.5的相同中值范围内。该模型应该输出与每种颜色相对应的0到8,对于每种颜色,我都有一个偶数个数据点用于训练。我是否可以使用更好的模型,以便将某些东西预测为0或7?我假设不会,因为它认为它们是某种离群值。

这是我的模特

const model = tf.sequential();

const hidden = tf.layers.dense({
  units: 3,
  inputShape: [3] //Each input has 3 values r, g, and b
});
const output = tf.layers.dense({
  units: 1 //only one output (the color that corresponds to the rgb values
    });
model.add(hidden);
model.add(output);

model.compile({
  activation: 'sigmoid',
  loss: "meanSquaredError",
  optimizer: tf.train.sgd(0.005)
});

这是解决我问题的好模型吗?

1 个答案:

答案 0 :(得分:1)

该模型缺乏非线性,因为没有激活函数。 给定一个rgb输入,模型应该以8种可能的值预测最可能的颜色。这是一个分类问题。问题中定义的模型正在进行回归,即正在尝试根据给定的输入预测数值。

对于分类问题,最后一层应预测概率。在这种情况下,softmax激活功能主要用于最后一层。 损失函数应该为categoricalCrossentropybinaryCrossEntropy(如果只有两种颜色可以预测)。

请考虑以下模型来预测红色,绿色和蓝色3种颜色

const model = tf.sequential();
model.add(tf.layers.dense({units: 10, inputShape: [3], activation: 'sigmoid' }));
model.add(tf.layers.dense({units: 10, activation: 'sigmoid' }));
model.add(tf.layers.dense({units: 3, activation: 'softmax' }));

model.compile({ loss: 'categoricalCrossentropy', optimizer: 'adam' });

const xs = tf.tensor([
  [255, 23, 34],
  [255, 23, 43],
  [12, 255, 56],
  [13, 255, 56],
  [12, 23, 255],
  [12, 56, 255]
]);

// Labels
const label = ['red', 'red', 'green', 'green', 'blue', 'blue']
const setLabel = Array.from(new Set(label))
const ys = tf.oneHot(tf.tensor1d(label.map((a) => setLabel.findIndex(e => e === a)), 'int32'), 3)

// Train the model using the data.
  model.fit(xs, ys, {epochs: 100}).then((loss) => {
  const t = model.predict(xs);
  pred = t.argMax(1).dataSync(); // get the class of highest probability
  labelsPred = Array.from(pred).map(e => setLabel[e])
  console.log(labelsPred)
}).catch((e) => {
  console.log(e.message);
})
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.3/dist/tf.min.js"> </script>
  </head>

  <body>
  </body>
</html>