如何依次加载MNIST数据进行训练?

时间:2020-07-15 18:03:28

标签: python tensorflow mnist

我现在正在使用MNIST和彩色数字版本来训练我的SemGAN模型。我已经加载并预处理了数据集。

x_train = x_train[0:10000]
y_train = y_train[0:10000]
x_test = x_test[0:10]
y_test = y_test[0:10]
x_train = x_train.reshape(-1, 28, 28, 1).astype(np.float32)
x_test = x_test.reshape(-1, 28, 28, 1).astype(np.float32) 

现在,我必须将每个重塑后的数据(28、28、1)顺序地(按批次或随机分组)输入模型,该怎么做?谁能帮我吗?

1 个答案:

答案 0 :(得分:0)

您需要定义一个神经网络架构。这是一个简单的CNN,仅用于说明目的,您可以针对您的案例研究GAN架构,并进行构建或提出一个单独的问题。

mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16,[3,3], activation='relu',
                            input_shape=(None, None, 1)),
    tf.keras.layers.Conv2D(16,[3,3], activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10)
])

然后,您需要编译损失,优化器,指标等。这又是分类的基本编译器。

mnist_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    metrics=[tf.keras.metrics.CategoricalAccuracy()])

现在,继续您的问题。一切设置完成后,您只需使用模型的拟合函数

mnist_model.fit(x_train, y_train, batch_size=64, epochs=1, shuffle=True, validation_split=0.05)

批次大小为64,因此拟合以64的间隔遍历所有批次,直到完成为止。有一个用于验证的拆分参数,以检查模型是否在训练中(而不是过度训练)-这样,结合64个批处理大小,将存在ceil(60000 * 0.95 / 64)迭代,其中60000来自于训练集(您选择了10000)。在这种情况下(默认情况下),该呼叫将在每个纪元随机调整训练集。最后,纪元为1,但实际上您可能希望更高。