限制MNIST训练数据的大小

时间:2019-07-23 15:33:13

标签: python tensorflow mnist

我刚刚开始学习python和TensorFlow,并正在尝试各种神经网络和MNIST数据。我想做的一个实验是查看训练集的大小如何影响性能。当前,训练集中似乎有55000个输入/输出对。我想以某种方式将培训限制为仅使用前1000个左右,但不知道如何实现。

我当前的训练功能如下:

def do_training():
    print("Train entry")
    for i in range(2000):

        batch_of_training_inputs, batch_of_training_labels = mnist.train.next_batch(100)

        sess.run(train_step, feed_dict={generic_image_data_struct: batch_of_training_inputs, target_for_output_struct: batch_of_training_labels })

有没有类似的东西...

mnist.train.next_batch(100, BUT_ONLY_FROM_FIRST(1000))

仅供参考,此代码让我很高兴:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2 个答案:

答案 0 :(得分:1)

您可以做的一件简单的事情就是增加验证数据集的大小。 MNIST包含60,000张图像,因此,如果您只想训练1,000张,则可以:

mnist = input_data.read_data_sets(train_dir, one_hot=True, validation_size=59000)

答案 1 :(得分:0)

通过一点点黑客攻击,我认为这可能有效。尽管我确实将来不建议使用此解决方案,因为它依赖于League league = new League() { leagueinfo = JsonConvert.DeserializeObject<LeagueInfo[]>(data) }; 方法的内部实现,该实现以某种方式进行。为了进行快速实验,可能没事。

DataSet.__init__