输入大小为1的训练会在随后的预测中导致NaN

时间:2019-12-14 00:04:56

标签: javascript tensorflow keras mnist tensorflow.js

我正在按照MNIST教程here来识别手写字符。

我能够毫无问题地加载和识别手写数字,但是现在我想在新图像上再次训练模型(特别是一次)。

由于某种原因,当我选择等于1的训练量时,我所有的预测都变为NaN。

如果我选择一个值> = 2,就可以正常工作。

火车功能:

async function train(model, data)
{
    const TRAIN_DATA_SIZE = 1; // WHEN THIS IS 1, CAUSES PREDICT TO OUTPUT NaN

    const [trainXs, trainYs] = tf.tidy(() =>
    {
        const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
        return [
            d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
            d.labels
        ];
    });

    console.log(trainXs.dataSync());
    console.log(trainYs.dataSync());

    return model.fit(trainXs, trainYs);
}

nextTrainBatch的代码为here

用于预测的示例输出:

currentTensor = tf.tensor2d(inputs, [1, PIXELSSQUARED]);

const output = model.predict(currentTensor.reshape([1, 28, 28, 1]));
const prediction_value = Array.from(output.argMax(1).dataSync());
console.log(output.dataSync());

训练大小为2或更大时:

Float32Array(10) [3.308702423154841e-9, 5.89648436744028e-8, 0.00005333929220796563, 0.8063259720802307, 7.401082784824764e-13, 1.1464327087651327e-7, 6.5924318955190575e-12, 0.1936144232749939, 0.000004253268798493082, 0.000001676815713835822]

训练大小为1时:

Float32Array(10) [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN]

1 个答案:

答案 0 :(得分:0)

该模型达到数值不稳定。使用SGD之类的优化程序可能会有所帮助。但是,批量大小为1实际上不是一个好主意,因为模型可能会在最佳值附近波动

  

我希望用户在模型做出预测后选择正确的值,例如进行预测,选择正确的输出,根据此信息进行重新训练

如果要进一步训练,则需要具有与模型inputShape匹配的数据。这样就可以收集预测值和用户选择的结果,并可以用于进一步训练模型

// the model has been trained
y = model.predict(x) // predict

假设用户将验证结果y。进一步训练

model.fit(x, y)

周期继续