我正在关注TensorFlow的Generative Adversarial Network教程。本教程使用MNIST数据集来训练模型。我想减小输入的大小,以便我的程序运行得更快,但不知道如何获取我正在使用的MNIST数据集的子集。下面是我用来提取数据集的代码:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/")
答案 0 :(得分:0)
有一种方法
mnist.next_batch(batchsize)
从列车集中提取长度batchsize的随机样本。
如果您不想要随机的内容,可以通过
访问它们x = mnist.train.images[start_batch:end_batch]
y = mnist.train.labels[start_batch:end_batch]
或类似于mnist.test
的测试集。