如何在Tensorflow 2中实现小批量梯度下降?

时间:2020-09-01 12:04:28

标签: python numpy tensorflow keras tensorflow-datasets

我对机器学习和Tensorflow还是比较陌生,我想尝试在MNIST数据集上实现小批量梯度下降。但是,我不确定应该如何实施。

(附带说明:训练图像(28像素x 28像素)和标签存储在Numpy数组中)

此刻,我可以看到两种不同的实现方式:

  1. 我的训练图像位于[60000,28,28]的Numpy数组中。将其重塑为[25(批处理数量),2400(批处理中的图像数量),28,28],然后使用for循环调用每个批处理并将其传递给model.compile()方法。我唯一担心使用此方法的是for循环本质上很慢,而矢量化的实现会更快。

  2. 将图像和标签合并到一个tensorflow数据集对象中,然后调用Dataset.batch()方法和Dataset.prefetch()方法,然后将数据传递给model.compile()方法。唯一的问题是我的数据不会保留为Numpy数组,我觉得它比tensorflow数据集对象具有更大的灵活性。

这两种方法中哪一种是最佳实施方式,还是我不知道的第三种最佳方式?

1 个答案:

答案 0 :(得分:1)

Keras的model.fit方法有一个内置的batch_size参数(因为您用keras标记了这个问题,所以我假设您正在使用它)。我相信这可能是实现您所寻找的最佳最佳方法。