我的数据集太大,无法一次性全部载入内存。因此,我尝试为 model.fit_generator()
编写下面的自定义生成器函数,它从 HDF5 文件中随机抽取一批样本及其对应的标签。我的实现如下。
# Batching configuration.
train_data = 11400  # NOTE(review): presumably the number of rows in the training set -- confirm
batch_size = 64
# Number of whole batches per epoch; any trailing partial batch is dropped.
steps_per_epoch = train_data // batch_size
def train_generator(batch_size, steps_per_epoch):
    """Yield random contiguous (x, y) batches from the training HDF5 file, forever.

    For every epoch, ``steps_per_epoch`` distinct start offsets are drawn
    without replacement. For each offset the ``batch_size`` consecutive rows
    of the ``'Data'`` and ``'Label'`` datasets starting there are read and
    yielded. Start offsets are not aligned to batch boundaries, so windows
    drawn within one epoch may overlap.

    Relies on the module-level ``train_data``.
    NOTE(review): assumes ``train_data`` equals the row count of both
    datasets -- confirm against the file.

    Args:
        batch_size: Number of consecutive rows read per batch.
        steps_per_epoch: Number of batches produced before start offsets
            are redrawn.

    Yields:
        A tuple ``(x_train, y_train)`` sliced from ``'Data'`` and ``'Label'``.

    Raises:
        ValueError: Raised by ``random.sample`` when ``steps_per_epoch``
            exceeds the number of valid start offsets.
    """
    train_dir = '.../train.h5'
    # +1 so the last window [train_data - batch_size, train_data) can be
    # chosen; without it the final row is never sampled.
    start_indices = range(train_data - batch_size + 1)
    # The context manager closes the file when the generator is closed or
    # garbage-collected instead of leaking the handle.
    with h5py.File(train_dir, 'r') as train_file:
        data = train_file['Data']
        labels = train_file['Label']
        while True:
            # Sampling without replacement gives distinct start offsets
            # within an epoch; a fresh draw happens on every epoch.
            for start_index in random.sample(start_indices, steps_per_epoch):
                end_index = start_index + batch_size
                x_train = data[start_index:end_index]
                y_train = labels[start_index:end_index]
                yield x_train, y_train
def vaidation_generator():
    """Load a fixed validation slice from the test HDF5 file.

    Reads rows 140..239 of the ``'Data'`` and ``'Label'`` datasets and
    flattens the labels to one dimension.

    Returns:
        A tuple ``(x_val, y_val)`` where ``y_val`` has been reshaped
        with ``reshape(-1)``.
    """
    val_dir = '.../test.h5'
    # Slicing copies the rows into memory, so the file can be closed
    # before returning.
    with h5py.File(val_dir, 'r') as val_file:
        x_val = val_file['Data'][140:240]
        y_val = val_file['Label'][140:240]
    y_val = y_val.reshape(-1)
    return x_val, y_val
用于训练:
# Load the fixed validation arrays and bundle them for Keras.
x_val, y_val = vaidation_generator()
validation = (x_val, y_val)

# Create the training batch generator.
training = train_generator(batch_size, steps_per_epoch)

# Train from the generator, evaluating on the in-memory validation tuple.
model.fit_generator(
    training,
    steps_per_epoch=steps_per_epoch,
    epochs=200,
    validation_data=validation,
    validation_steps=None,
)
请帮忙看看这样写是否可行,以及我有没有哪里做错。