为什么model.fit需要二维张量?为什么model.predict不接受标量张量?

时间:2019-11-07 09:36:58

标签: javascript multidimensional-array neural-network tensor tensorflow.js

当我注意到model.fit必须接受两个参数(输入和输出以及一些配置)时,我正在学习TensorFlow.js。但是输入是二维张量,其编写如下:

let input = tf.tensor2d([1, 2, 3, 4, 5], [5, 1])

这看起来非常像一维张量,写如下:

let input = tf.tensor1d([1, 2, 3, 4, 5])

由于二维张量实际上是一个5乘1,我决定将其替换为一维张量。但是,这完全停止了程序的工作。那么是否有某种类型的代码说输入必须是二维的?如果可以,为什么?

关于多维张量,我还注意到model.predict不能接受零维张量或标量。见下文

Working Code:

model.predict(tf.tensor1d([6]))

Not Working Code:

model.predict(tf.scalar(6))

如果任何人都可以澄清这些限制背后的原因,将不胜感激。

1 个答案:

答案 0 :(得分:0)

2D张量不是1D张量。 tf.tensor2d([1, 2, 3, 4, 5], [5, 1])不是tf.tensor1d([1, 2, 3, 4, 5])。一个可以转换为另一个,但这并不意味着它们相等。

model.fit将张量或等级2或更大作为参数。该张量可以视为元素的数组,其形状已指定给模型的输入。模型的inputShape至少是等级1,这使model.fit参数至少为2(1 + 1始终是inputShape的等级+ 1)。

由于model.fit和model.predict将具有相同秩的张量作为参数,因此出于上述相同的原因,model.predict参数是秩2或更高的张量。

但是,model.predict(tf.tensor1d([6]))有效。这样做是因为在内部,tensorflow.js会将一维张量转换为二维张量。形状为[6]的初始张量将转换为形状[6,1]的张量。

model.predict(tf.tensor1d([6])) 
// will work because it is a 1D tensor 
// and only in the case where the model first layer inputShape is [1]

model.predict(tf.tensor2d([[6]])) 
// will also work
// One rank higher than the inputShape and of shape [1, ...InputShape]

model.predict(tf.scalar(6)) // will not work

const model = tf.sequential(
    {layers: [tf.layers.dense({units: 1, inputShape: [1]})]});
model.predict(tf.ones([3])).print(); // works
model.predict(tf.ones([3, 1])).print(); // works
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
  </head>

  <body>
  </body>
</html>

const model = tf.sequential(
    {layers: [tf.layers.dense({units: 1, inputShape: [2]})]});
model.predict(tf.ones([2, 2])).print(); // works
model.predict(tf.ones([2])).print(); // will not work
   // because [2] is converted to [2, 1]
   // whereas the model is expecting an input of shape [b, 2] with b an integer
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
  </head>

  <body>
  </body>
</html>