我正在使用keras模型的功能版本来训练我的自动编码器网络,并且正在使用来自keras的model.fit()函数,如下所示:
model.fit(train_data, train_ground_truth_data, validation_data = (validation_data, validation_ground_truth_data), epochs=1000, batch_size=40, callbacks=callbacks)
我能够以此方式训练模型。这时,我正在使用以下命令立即将数据转换为float32数据类型: train_data = train_data.astype('float32')/ 255。
因为我有大量数据,当将其转换为 float32 时,会导致RAM耗尽,所以我决定将其批量转换并使用fit_generator训练我的模型。以下代码批量加载图像及其基本情况。
def imageLoader(data, ground_truth_data, batch_size):
L = data.shape[0] # there are equal numbers of images in data and its ground_truth_data
#this line is just to make the generator infinite, keras needs that
while True:
batch_start = 0
batch_end = batch_size
while batch_start < L:
limit = min(batch_end, L)
#print(limit)
X = data[batch_start:limit, :].astype('float32') / 255.
Y = ground_truth_data[batch_start:limit, :].astype('float32') / 255.
yield (X,Y) #a tuple with two numpy arrays with batch_size samples
batch_start += batch_size
batch_end += batch_size
我使用以下代码行使用来自keras的fit_generator函数进行训练。
model.fit_generator(imageLoader(train_data,train_ground_truth_data, batch_size), validation_data = imageLoader(validation_data, validation_ground_truth_data, batch_size), epochs=1000, steps_per_epoch = math.ceil(num_samples / batch_size), validation_steps = math.ceil(num_validation_samples / batch_size), callbacks=callbacks)
我想检查这种使用fit_generator()进行训练的方法是否与fit()函数相同的方式,我使用了一个小的数据集,并使用这两种方法对其进行了测试。但是问题是,当我使用fit_generator时,模型无法训练。当我使用imageLoader()函数使用model.fit_generator()函数时,它正在正确地训练。验证损失和训练损失在2个时期后保持不变。
使用fit_generator()这样做不是正确的方法吗?