使用整个MNIST数据集(60000张图像)训练张量流需要多少次迭代?

时间:2016-08-11 15:44:53

标签: python tensorflow mnist

MNIST集由60,000个训练集图像组成。在训练我的Tensorflow时,我想运行火车步骤来训练整个训练集的模型。 Tensorflow网站上的深度学习示例使用20,000次迭代,批次大小为50(总计1,000,000批次)。当我尝试超过30,000次迭代时,我的数字预测失败(对所有手写数字预测为0)。我的问题是,我应该使用批量大小为50的迭代次数来训练具有整个MNIST集的张量流模型?

self.mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
for i in range(FLAGS.training_steps):
    batch = self.mnist.train.next_batch(50)
    self.train_step.run(feed_dict={self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5})
    if (i+1)%1000 == 0:
       saver.save(self.sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step = i)

4 个答案:

答案 0 :(得分:2)

我认为这取决于您的停止标准。您可以在损失没有改善时停止培训,或者您可以拥有验证数据集,并在验证准确性不再提高时停止培训。

答案 1 :(得分:2)

通过机器学习,您往往会遇到收益递减的严重情况。例如,这里是我的一个CNN的准确性列表:

Epoch 0 current test set accuracy :  0.5399
Epoch 1 current test set accuracy :  0.7298
Epoch 2 current test set accuracy :  0.7987
Epoch 3 current test set accuracy :  0.8331
Epoch 4 current test set accuracy :  0.8544
Epoch 5 current test set accuracy :  0.8711
Epoch 6 current test set accuracy :  0.888
Epoch 7 current test set accuracy :  0.8969
Epoch 8 current test set accuracy :  0.9064
Epoch 9 current test set accuracy :  0.9148
Epoch 10 current test set accuracy :  0.9203
Epoch 11 current test set accuracy :  0.9233
Epoch 12 current test set accuracy :  0.929
Epoch 13 current test set accuracy :  0.9334
Epoch 14 current test set accuracy :  0.9358
Epoch 15 current test set accuracy :  0.9395
Epoch 16 current test set accuracy :  0.942
Epoch 17 current test set accuracy :  0.9436
Epoch 18 current test set accuracy :  0.9458

正如您所看到的,在大约10个纪元*之后,回报开始下降,但这可能会因您的网络和学习率而异。根据你有多少关键/多少时间来做好的数量会有所不同,但我发现20是一个合理的数字

*我一直使用epoch这个词来表示整个数据集的运行,但我不知道该定义的准确性,这里的每个时期是〜429个训练步骤,批量为128个。

答案 2 :(得分:0)

您可以使用类似 no_improve_epoch 的内容并将其设置为让我们说3.它只是意味着如果在3次迭代中没有> 1%的改善,那么停止迭代。

no_improve_epoch= 0
        with tf.Session() as sess:
            sess.run(cls.init)
            if cls.config.reload=='True':
                print(cls.config.reload)
                cls.logger.info("Reloading the latest trained model...")
                saver.restore(sess, cls.config.model_output)
            cls.add_summary(sess)
            for epoch in range(cls.config.nepochs):
                cls.logger.info("Epoch {:} out of {:}".format(epoch + 1, cls.config.nepochs))
                dev = train
                acc, f1 = cls.run_epoch(sess, train, dev, tags, epoch)

                cls.config.lr *= cls.config.lr_decay

                if f1 >= best_score:
                    nepoch_no_imprv = 0
                    if not os.path.exists(cls.config.model_output):
                        os.makedirs(cls.config.model_output)
                    saver.save(sess, cls.config.model_output)
                    best_score = f1
                    cls.logger.info("- new best score!")

                else:
                    no_improve_epoch+= 1
                    if nepoch_no_imprv >= cls.config.nepoch_no_imprv:
                        cls.logger.info("- early stopping {} Iterations without improvement".format(
                            nepoch_no_imprv))
                        break

Sequence Tagging GITHUB

答案 3 :(得分:0)

我发现,使用MNIST,每个时期对3,833张图像的训练(在56,167 because 60k ** 0.75上进行验证的时间刚好超过3.833)在500个时代之前趋于收敛。 “收敛”是指批量大小为16的连续50个训练周期的验证损失不会减少;有关通过tf.keras使用提前停止的示例,请参见this回购;在这种情况下,这对我很重要,因为我正在进行模型搜索,并且没有时间训练很长的单个模型。