如何在Tensorflow.js中使用model.fit的classWeight选项?

时间:2019-05-08 12:30:29

标签: javascript tensorflow.js

在我的训练集中,我的课程没有得到平等的体现。因此,我试图使用model.fit函数的classWeight选项。引用文档:

  

classWeight ({[classIndex: string]: number})可选的字典映射类索引(整数)到权重(浮点数),以应用于训练过程中该类样本的模型损失。这可能有助于告诉模型“更多关注”来自代表性不足的类的样本。

所以这似乎正是我要寻找的东西,但是我无法弄清楚如何使用它。 classIndex在文档中没有其他地方,而且我不明白如何将类编写为索引。由于我的代码是一键编码,因此我什至尝试使用该索引(0用于[1,0,0]等),但是它没有用(请参见下文)。

这是一个最小的代码示例:

const xs = tf.tensor2d([ [1,2,3,4,5,6], /* ... */ ]); // shape: [1000,6]
const ys = tf.tensor2d([ [1,0,0], /* ... */ ]);       // shape: [1000,3] (one-hot encoded classes)

const model = tf.sequential();
// ... some layers

await model.fit(xs, ys, {
    /* ... */
    classWeight: {
        0: 10000, // doesn't do anything
        1: 1,
        2: 1,
    },
});

我希望我的模型“更多地关注”我的第一堂课([1,0,0]),因此会更频繁地对其进行预测。但是看来Tensorflow.js只是忽略了该参数。

0 个答案:

没有答案