如何从特定索引获取元素,其中索引是使用TensorFlow.js的标量张量?

时间:2020-08-02 18:56:43

标签: javascript tensorflow tensorflow.js

我最近完成了一门课程,其中我们使用了较旧版本的TensorFlow.js,并且在张量(不仅仅是缓冲区)上有一个有用的方法:.get()。由于已将其删除,因此我必须使用其他解决方案来创建简化的学习率优化,在其中将先前的成本与新的成本进行比较,如果先前的成本较大,则提高学习率,否则降低学习率。 Cost始终是一个标量张量,我将前一个与新的费用堆叠在一起,用.argMax()获得较大的索引,然后从我的“常数”张量中获取项目,该张量仅存储两个值,将学习率乘以.argMax()的结果。

一个例子是:

let learningRate = tf.tensor(1);

const prevCost = tf.tensor(1);
const nextCost = tf.tensor(2);

const modifiers = tf.tensor([1.05, 0.5]);

const bigger = tf.stack([prevCost, nextCost]).argMax(); // 1

const modifier = modifiers.get(// if it would still exist
  bigger
); // 0.5

learningRate = learningRate.mul(modifier); // 1 * 0.5 = 0.5

但是不幸的是,.get()不再存在,但是,应该有一种方法可以做到这一点。

1 个答案:

答案 0 :(得分:1)

tf.slice可以按照说明的here

使用
tensor.slice([...cordinates], 1)