这是在model.fit(),model.train_on_batch(),model.fit_generator()之间最合适的训练方法

时间:2018-12-25 12:06:06

标签: python-3.x tensorflow keras deep-learning computer-vision

我有一个训练数据集,其中包含分辨率为(512 * 512 * 1)的600张图像,分为2类(每类300张图像)。使用一些扩充技术,我已将数据集增加到10000张图像。经过以下预处理步骤

all_images=np.array(all_images)/255.0
all_images=all_images.astype('float16')
all_images=all_images.reshape(-1,512,512,1)
saved these images to H5 file.

我将AlexNet架构用于分类,具有3个卷积,3个重叠的最大池层。 我想知道以下哪种情况最适合在内存大小限制为12GB的情况下使用Google Colab进行培训。

1. model.fit(x,y,validation_split=0.2)
# For this I have to load all data into memory and then applying an AlexNet to data will simply cause Resource-Exhaust error.

2. model.train_on_batch(x,y)
# For this I have written a script which randomly loads the data batch-wise from H5 file into the memory and train on that data. I am confused by the property of train_on_batch() i.e single gradient update. Do this will affect my training procedure or will it be same as model.fit().

3. model.fit_generator() 
# giving the original directory of images to its data_generator function which automatically augments the data and then train using model.fit_generator(). I haven't tried this yet. 

请指导我,这对我来说是最好的方法。我已经阅读了许多关于model.fit(),model.train_on_batch()和model.fit_generator()的答案HereHereHere,但我仍然感到困惑。

1 个答案:

答案 0 :(得分:1)

model.fit-如果您将数据加载为numpy-array并进行训练而不进行扩充,则适用。 model.fit_generator-如果您的数据集太大而无法容纳在内存中,或者\并且您想即时应用增强。 model.train_on_batch-较少见,通常在一次训练多个模型(例如GAN)时使用