当达到特定的损失和准确性值时,如何停止tflearn训练时期或迭代?

时间:2019-01-09 10:45:27

标签: python machine-learning tflearn

我有一个使用tflearn库训练的模型,我使用深度神经网络(DNN)来做到这一点。我们可以在这里看到更多(http://tflearn.org/models/dnn/

下面是我的代码:

# Build neural network
net = tflearn.input_data(shape=[None, len(train_x[0])])
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, len(train_y[0]), activation='softmax')
net = tflearn.regression(net)

# Define model and setup tensorboard
model = tflearn.DNN(net, tensorboard_dir='tflearn_logs', best_val_accuracy=0.91)
# Start training (apply gradient descent algorithm)
model.fit(train_x, train_y, n_epoch=350, batch_size=8, show_metric=True)
model.save('model.tflearn')

运行该代码时,我会得到一些像这样的值,直到纪元结束为止:

Training Step: 5083  | total loss: 0.31890 | time: 0.302s
| Adam | epoch: 085 | loss: 0.31890 - acc: 0.8948 -- iter: 344/474
Training Step: 20999  | total loss: 0.08880 | time: 0.366s
....
Training Step: 11279  | total loss: 0.10708 | time: 0.419s
| Adam | epoch: 188 | loss: 0.10708 - acc: 0.9556 -- iter: 472/474
Training Step: 11280  | total loss: 0.12302 | time: 0.425s
| Adam | epoch: 188 | loss: 0.12302 - acc: 0.9351 -- iter: 474/474
....
| Adam | epoch: 350 | loss: 0.08880 - acc: 0.9503 -- iter: 472/474
Training Step: 21000  | total loss: 0.08863 | time: 0.373s
| Adam | epoch: 350 | loss: 0.08863 - acc: 0.9553 -- iter: 474/474

任何人都知道每次损失和准确性达到特定值时如何停止训练?假设损失0.05,准确性为0.95。 预先感谢

1 个答案:

答案 0 :(得分:1)

通过作为fit方法的参数给出的回调实例使用Early Stopping,如此处所述:

http://mckinziebrandon.me/TensorflowNotebooks/2016/11/20/early-stopping.html

当精度达到0.95时,这样的事情应该可以停止训练

<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>

partial class Program {
  <# for(int i = 2; i < 16; i++)
    {
      Write("public void HandleComponents<");
      for(int j = 1; j<=i ;j++) {
      if (j>1) {
            Write(",");
        }
            Write("T{0}",j);
      }
        WriteLine(">()");
        for(int j = 1; j<= i; j++) {
          WriteLine("where T{0} : IComponent",j);
        }
        WriteLine("{");
        for(int j = 1; j<= i; j++) {
          WriteLine("HandleComponent<T{0}>();",j);
        }
        WriteLine("}");
  } #>
}