在tensorflow.js中对Tensor进行分区或屏蔽或过滤

时间:2018-07-06 22:39:19

标签: tensorflow tensorflow.js

我有两个长度相同的张量datagroupIds。我想通过data中的相应值将groupId分成几组。例如,

const data = tf.tensor([1,2,3,4,5]);
const groupIds = tf.tensor([0,1,1,0,0]);
// expected result: [tf.tensor([1,4,5]), tf.tensor([2,3])]

在Tensorflow中,tf.dynamic_partition正是这样做的。 Tensorflow.js似乎没有类似的方法。我也将遮罩或过滤作为解决方法,但它们也不存在。有谁知道如何实现这一目标?

1 个答案:

答案 0 :(得分:0)

要对张量进行分区,您可以首先在ids张量上进行迭代,以获取要创建的张量的数量及其应包含的元素的索引。该信息可以存储在一个对象中,其中键是ids数组中的分区号,而值是索引数组。

const data = tf.tensor([6,2,8,4,5]);
const ids = tf.tensor([0,1,1,0,2]);

const data2 = tf.tensor([[6,2],[8,4], [5, 4], [6, 5]]);
const ids2 = tf.tensor([0,1,1,0]);

const filterT = (t, p) => {
  t.print()
  p.print()
  const l = p.unstack().reduce((a, b, i) => {
  const v = b.dataSync()[0]
  if (Object.keys(a).includes(v.toString())) {
    a[v].push(i)
  } else {
    a[v] = [i]
  }
  return a
}, {})

  const r = Object.keys(l).map(k => t.gather(tf.tensor1d(l[k], 'int32')))
  r.forEach(e => e.print())
}

filterT(data, ids)
filterT(data2, ids2)
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/0.12.4/tf.js"> </script>
  </head>

  <body>
  </body>
</html>