Randomly sampling batches for model.fit_generator()

Asked: 2018-12-25 22:24:31

Tags: python-3.x tensorflow keras deep-learning

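My dataset is too large to fit in memory, so I tried to write the following custom generator for model.fit_generator(). It draws random samples, together with their labels, from an HDF5 file. My attempt is below.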

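import random
import h5py

train_data = 11400                          # number of training samples in train.h5
batch_size = 64
steps_per_epoch = train_data // batch_size  # 178 full batches per epoch

def train_generator(batch_size, steps_per_epoch):
    train_dir = '.../train.h5'
    train_file = h5py.File(train_dir, 'r')
    while True:
        # A batch starting at s covers [s, s + batch_size), so the largest
        # valid start offset is train_data - batch_size (inclusive).
        random_numbers_range = list(range(0, train_data - batch_size + 1))
        for batch in range(steps_per_epoch):
            # Pick a random start offset and drop it so it is not reused this epoch.
            start_index = random.choice(random_numbers_range)
            random_numbers_range.remove(start_index)
            end_index = start_index + batch_size
            # Sequential (non-random) alternative:
            # start_index = batch * batch_size
            # end_index = (batch + 1) * batch_size
            # Only this contiguous slice is read from disk into memory.
            x_train = train_file['Data'][start_index:end_index]
            y_train = train_file['Label'][start_index:end_index]
            yield x_train, y_train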
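Because each random start offset picks a contiguous window, the windows can overlap, so one epoch may read some samples several times and skip others entirely. A minimal sketch of an alternative, assuming the goal is to see every sample once per epoch in random order: shuffle the order of non-overlapping, fixed-size batches instead of picking arbitrary offsets. The function name `shuffled_batch_generator` is hypothetical; `data` and `labels` are assumed to be anything sliceable, such as `train_file['Data']` and `train_file['Label']`.

```python
import random

def shuffled_batch_generator(data, labels, num_samples, batch_size, seed=None):
    """Yield (x, y) batches forever; every full batch is visited once per epoch.

    `data` / `labels` only need to support slicing (h5py datasets, NumPy
    arrays, or plain lists all work).
    """
    rng = random.Random(seed)
    num_batches = num_samples // batch_size  # the trailing partial batch is skipped
    batch_ids = list(range(num_batches))
    while True:
        rng.shuffle(batch_ids)  # new random batch order at the start of each epoch
        for b in batch_ids:
            start = b * batch_size
            end = start + batch_size
            # Contiguous slice: cheap to read from HDF5, only one batch in memory.
            yield data[start:end], labels[start:end]
```

Usage would be along the lines of `shuffled_batch_generator(train_file['Data'], train_file['Label'], train_data, batch_size)` with `steps_per_epoch = train_data // batch_size`. Only the batch order is shuffled; samples inside a batch stay in file order, so if the file itself is sorted (e.g. by class), it may be worth shuffling it once when the HDF5 file is written.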

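def load_validation_data():
    # Plain function (not a generator): loads a fixed 100-sample slice
    # of the test set into memory once.
    val_dir = '.../test.h5'
    with h5py.File(val_dir, 'r') as val_file:
        x_val = val_file['Data'][140:240]
        y_val = val_file['Label'][140:240]
    y_val = y_val.reshape(-1)  # flatten labels to shape (100,)
    return x_val, y_val

For training:

training = train_generator(batch_size, steps_per_epoch)
x_val, y_val = load_validation_data()
validation = (x_val, y_val)
# validation_steps is only used when validation_data is a generator.
# In TensorFlow 2.1+, fit_generator is deprecated and model.fit accepts generators directly.
model.fit_generator(training, steps_per_epoch=steps_per_epoch, epochs=200,
                    validation_data=validation, validation_steps=None)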
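If the validation set is also too large to load at once, it can be streamed in the same way. A minimal sketch, assuming Keras draws `validation_steps` batches from the same validation generator at the end of every epoch (so the generator must loop forever). The name `sequential_batch_generator` is hypothetical; validation batches are read in file order and the last partial batch is included.

```python
def sequential_batch_generator(data, labels, num_samples, batch_size):
    """Yield (x, y) batches in file order, restarting from the top after the last one."""
    num_steps = -(-num_samples // batch_size)  # ceiling division: keep the partial batch
    while True:
        for step in range(num_steps):
            start = step * batch_size
            end = min(start + batch_size, num_samples)
            yield data[start:end], labels[start:end]
```

With this, `validation_data` would be the generator and `validation_steps` would be `ceil(num_val_samples / batch_size)`, so each validation sample is evaluated exactly once per epoch. Labels may still need the same `reshape(-1)` as above.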

Could you check whether this is OK and that I am not doing anything wrong?
