我正在按照MNIST教程here来识别手写字符。
我能够毫无问题地加载和识别手写数字,但是现在我想在新图像上再次训练模型(特别是一次)。
由于某种原因,当我选择等于1的训练量时,我所有的预测都变为NaN。
如果我选择一个值> = 2,就可以正常工作。
火车功能:
async function train(model, data)
{
const TRAIN_DATA_SIZE = 1; // WHEN THIS IS 1, CAUSES PREDICT TO OUTPUT NaN
const [trainXs, trainYs] = tf.tidy(() =>
{
const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
return [
d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
d.labels
];
});
console.log(trainXs.dataSync());
console.log(trainYs.dataSync());
return model.fit(trainXs, trainYs);
}
nextTrainBatch
的代码为here。
用于预测的示例输出:
currentTensor = tf.tensor2d(inputs, [1, PIXELSSQUARED]);
const output = model.predict(currentTensor.reshape([1, 28, 28, 1]));
const prediction_value = Array.from(output.argMax(1).dataSync());
console.log(output.dataSync());
训练大小为2或更大时:
Float32Array(10) [3.308702423154841e-9, 5.89648436744028e-8, 0.00005333929220796563, 0.8063259720802307, 7.401082784824764e-13, 1.1464327087651327e-7, 6.5924318955190575e-12, 0.1936144232749939, 0.000004253268798493082, 0.000001676815713835822]
训练大小为1时:
Float32Array(10) [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN]
答案 0 :(得分:0)
该模型达到数值不稳定。使用SGD之类的优化程序可能会有所帮助。但是,批量大小为1实际上不是一个好主意,因为模型可能会在最佳值附近波动
我希望用户在模型做出预测后选择正确的值,例如进行预测,选择正确的输出,根据此信息进行重新训练
如果要进一步训练,则需要具有与模型inputShape匹配的数据。这样就可以收集预测值和用户选择的结果,并可以用于进一步训练模型
// the model has been trained
y = model.predict(x) // predict
假设用户将验证结果y。进一步训练
model.fit(x, y)
周期继续