Tensorflow.Js问题:“错误:get()中的坐标数必须与张量的等级匹配”

时间:2018-10-17 15:38:55

标签: tensorflow.js

这是什么? 我在youtube视频“ 6.3:TensorFlow.js:变量与运算-智能和学习”中学习tensorflow.js。

一切正常,直到我尝试使用此get()为止。

const getRandomInt = (max) => {
return Math.floor(Math.random() * Math.floor(max));
};

   const values = [];
  for (let i = 0; i< 30; i++) {
    values[i] = getRandomInt(10);
  }

const shape = [2, 5, 3];

const matriisi = tf.tensor3d(values, shape, 'int32');

console.log(matriisi.get(3));

Web控制台说:

  

“错误:get()中的坐标数必须与张量的等级匹配”

1 个答案:

答案 0 :(得分:0)

get函数的参数数量应与张量的等级匹配。 您的张量为3或3,这意味着get应该具有3个参数。张量中的第三个元素具有以下坐标:[0, 1, 0]。您宁可使用matriisi.get(0, 1, 0)

通过索引获取元素的另一种方法是使用dataSync()data()以获得可以通过索引访问的类似数组的元素。

const a = tf.randomNormal([2, 5, 3], undefined, undefined, undefined, 3);

const indexToCoords = (index, shape) => {
  const pseudoShape = shape.map((a, b, c) => c.slice(b + 1).reduce((a, b) => a * b, 1));
  let coords = [];
  let ind = index;
  for (let i = 0; i < shape.length; i++) {
    coords.push(Math.floor(ind / pseudoShape[i]));
    ind = ind % pseudoShape[i];
  }
  return coords
}
const coords = indexToCoords(3, [2, 5, 3]);
// get the element of index 3
console.log(a.get(...coords));

// only dataSync will do the trick 
console.log(a.dataSync()[3])
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.0"> </script>
  </head>

  <body>
  </body>
</html>