TensorFlow.js:同时使用 EarlyStopping 和训练日志回调时不起作用

时间:2019-09-26 11:43:14

标签: tensorflow.js

当我们同时定义提前停止(EarlyStopping)回调和训练日志回调时,TensorFlow.js 似乎无法正常工作。下面的示例取自 TensorFlow.js 文档,我只是额外添加了一个 onTrainBegin 回调,训练就失败了。

// Example from the tf.callbacks.earlyStopping docs, plus a custom
// onTrainBegin hook, with the `callbacks` list corrected.
//
// The question originally passed `callbacks: [onTrainBegin, earlyStopping]`.
// A bare function is not a valid callback entry, and that form is what threw
// "this.getMonitorValue is not a function".
// NOTE(review): per tfjs-layers' callback standardization, the array is used
// as-is only when its first element is a tf.Callback instance. Otherwise every
// entry, including the EarlyStopping instance, is wrapped as a CustomCallback
// config, and EarlyStopping's methods lose their `this`. Verify against the
// tfjs version in use.
const model = tf.sequential();
model.add(tf.layers.dense({
  units: 3,
  activation: 'softmax',
  kernelInitializer: 'ones',
  inputShape: [2],
}));
const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
model.compile(
    {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});

// Custom hook that logs once when training starts; `logs` is unused here.
const onTrainBegin = function onTrainBegin(logs) {
  console.log("onTrainBegin");
};

// Without the EarlyStopping callback, the val_acc value would be:
//   0.5, 0.5, 0.5, 0.5, ...
// With val_acc being monitored, training should stop after the 2nd epoch.
// Top-level `await` requires an ES module (or a console that supports it).
const history = await model.fit(xs, ys, {
  epochs: 10,
  validationData: [xsVal, ysVal],
  // Every entry is a callback object. Custom hooks are wrapped in
  // tf.CustomCallback so they can sit alongside built-ins like earlyStopping.
  callbacks: [
    new tf.CustomCallback({onTrainBegin}),
    tf.callbacks.earlyStopping({monitor: 'val_acc'}),
  ],
});

// Expect to see a length-2 array.
console.log(history.history.val_acc);

这段代码会产生如下错误消息:

> 发生错误:this.getMonitorValue is not a function(this.getMonitorValue 不是函数)

https://js.tensorflow.org/api/latest/#callbacks.earlyStopping

2 个答案:

答案 0 :(得分:0)

您把两种不同的东西混在了一起。onTrainBegin 是回调钩子的名称,它指定回调函数在何时执行(训练开始时),应作为回调配置对象的一个属性传入;而 tf.callbacks.earlyStopping({monitor: 'val_acc'}) 返回的是一个现成的回调对象。

// Answer 0: the question's example wrapped in an async IIFE, so `await` works
// in a classic (non-module) browser script.
(async () => {
  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 3,
    activation: 'softmax',
    kernelInitializer: 'ones',
    inputShape: [2],
  }));
  const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
  const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
  const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
  const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
  model.compile(
      {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});

  // Custom hook that logs once when training starts; `logs` is unused here.
  const onTrainBegin = (logs) => {
    console.log("onTrainBegin");
  };

  // Without the EarlyStopping callback, the val_acc value would be:
  //   0.5, 0.5, 0.5, 0.5, ...
  // With val_acc being monitored, training should stop after the 2nd epoch.
  const history = await model.fit(xs, ys, {
    epochs: 10,
    validationData: [xsVal, ysVal],
    callbacks: [
      // Pass the function itself. Writing `onTrainBegin()` would run it once,
      // right here, and register its return value (`undefined`) as the hook.
      // Wrapping it in tf.CustomCallback makes every entry a callback object,
      // so the EarlyStopping instance below is not mistaken for a plain config.
      new tf.CustomCallback({onTrainBegin}),
      tf.callbacks.earlyStopping({monitor: 'val_acc'}),
    ],
  });

  // Expect to see a length-2 array.
  console.log(history.history.val_acc);
})().catch((err) => console.error(err));
<!DOCTYPE html>
<html>
  <head>
    <!-- Explicit doctype and charset keep the page in standards mode with a known encoding. -->
    <meta charset="utf-8">
    <!-- Load TensorFlow.js; the snippet above uses the global `tf` without importing it.
         `@latest` tracks the newest release; pin a version for reproducible results. -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
  </head>

  <body>
  </body>
</html>

答案 1 :(得分:0)

以下是在 tfjs 模型训练中同时使用 earlyStopping 和 tf.CustomCallback 的可运行代码示例。

// Training accuracy from the most recent epoch. onEpochEnd writes it and
// onTrainEnd reads it. The original assigned to an undeclared `acc`, which
// creates an implicit global and throws a ReferenceError in strict mode or
// ES modules.
let lastEpochAcc;

// This list works alongside earlyStopping because its first entry is already
// a tf.CustomCallback instance, so no entry is treated as a plain config.
// NOTE(review): `logs.acc` / `logs.val_acc` assume the model was compiled with
// metrics: ['acc']; `logs.val_*` assume validationData yields validation
// metrics each epoch. Confirm both.
await model.fitDataset(convertedTrainingData, {
  epochs: 50,
  validationData: convertedTestingData,
  callbacks: [
    new tf.CustomCallback({
      onEpochEnd: (epoch, logs) => {
        lastEpochAcc = logs.acc;
        console.log(
            `Epoch: ${epoch}` +
            ` Loss: ${logs.loss.toFixed(4)}` +
            ` Accuracy: ${logs.acc.toFixed(4)}` +
            ` Val Loss: ${logs.val_loss.toFixed(4)}` +
            ` Val Accuracy: ${logs.val_acc.toFixed(4)}`);
      },
      onTrainEnd: () => {
        console.log("training done");
        // NOTE(review): `repeat` is not declared in this snippet; presumably
        // it is a flag owned by the enclosing code (e.g. a retrain loop).
        if (lastEpochAcc > 0.4) {
          repeat = false;
          console.log(repeat);
        }
      },
    }),
    tf.callbacks.earlyStopping({monitor: 'loss'}),
  ],
});