将TensorFlow中的张量值转换为常规Javascript数组

时间:2019-02-07 20:57:37

标签: javascript arrays tensorflow.js vector-multiplication

我在TensorFlow.js框架中的两个一维数组(a,b)上使用了externalProduct函数,但是我发现很难以常规javascript格式获取结果张量的值。

即使使用.dataSync和Array.from()之后,我仍然无法获得预期的输出格式。两个1D数组之间的外部乘积应给出一个2D数组,但我却得到1D数组。

const a = tf.tensor1d([1, 2]);
const b = tf.tensor1d([3, 4]);
const tensor = tf.outerProduct(b, a);
const values = tensor.dataSync();
const array1 = Array.from(values);

console.log(array1);

预期结果是array1 = [[3,6],[4,8]],但是我得到 array1 = [3,6,4,8]

4 个答案:

答案 0 :(得分:4)

版本<15

tf.datatf.dataSync的结果始终是一个扁平数组。但是人们可以使用张量的形状通过mapreduce获得多维数组。

const x = tf.tensor3d([1, 2 , 3, 4 , 5, 6, 7, 8], [2, 4, 1]);

x.print()

// flatten array
let arr = x.dataSync()

//convert to multiple dimensional array
shape = x.shape
shape.reverse().map(a => {
  arr = arr.reduce((b, c) => {
  latest = b[b.length - 1]
  latest.length < a ? latest.push(c) : b.push([c])
  return b
}, [[]])
console.log(arr)
})
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.1"> </script>
  </head>

  <body>
  </body>
</html>

从0.15版开始

一个人可以使用tensor.array()tensor.arraySync()

答案 1 :(得分:1)

从tfjs版本0.15.1开始,您可以使用await tensor.array()来获取嵌套数组。

答案 2 :(得分:0)

您可以拿起values并做

const values = [3, 6, 4, 8];

let array1 = []

for (var i = 0; i < values.length; i += 2) {
  array1.push([values[i], values[i + 1]])
}

console.log(array1)

答案 3 :(得分:-1)

const a_const = [1, 2, 3, 4, 5, 6, 7, 8, 9];
const b_const = [10, 11, 12, 13]
const am = tf.tensor1d(a_const);
const bm = tf.tensor1d(b_const);

const tensor = tf.outerProduct(bm, am);
console.log(tensor.print())

使用@edkeveked多维数组转换函数,提出了一种粗略的方法。

此例的预期输出为

//Result from TensorFlow.js
Tensor
    [[10, 20, 30, 40, 50, 60, 70, 80 , 90 ],
     [11, 22, 33, 44, 55, 66, 77, 88 , 99 ],
     [12, 24, 36, 48, 60, 72, 84, 96 , 108],
     [13, 26, 39, 52, 65, 78, 91, 104, 117]]

//Results using the crude approach
ar_first [ [ 10, 20, 30, 40, 50, 60, 70, 80, 90 ],
  [ 11, 22, 33, 44, 55, 66, 77, 88, 99 ],
  [ 12, 24, 36, 48, 60, 72, 84, 96, 108 ],
  [ 13, 26, 39, 52, 65, 78, 91, 104, 117 ] ]