我想在Keras的每(X_train, y_train)
个时期传递另一个训练数据集N
,其中(X_train, y_train)
是通过蒙特卡罗模拟获得的。
在伪代码中,它将通过以下方式完成:
for i in range(nb_total_epochs):
if i%N == 0:
X_train, y_train = generate_new_dataset(simulation_parameters)
train_model(X_train, y_train)
使用fit()
函数是否有任何现有技巧?
答案 0 :(得分:3)
使用Sequence
创建数据集并将其传递给fit_generator
。定义on_epoch_end
方法以修改某些时期的数据集。
每个
Sequence
必须实现__getitem__
和__len__
方法。 如果您想在时期之间修改数据集,可以实施on_epoch_end
。方法__getitem__
应该返回完整的批次。
此外,您可以安全地将Sequence
用于多处理数据处理:
keras.utils.Sequence
的使用保证了排序,并保证在使用use_multiprocessing=True
时单个使用每个时期的每个输入。
从Sequence
文档略微修改,以包含on_epoch_end
。
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.epoch = 0
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
def on_epoch_end(self):
if self.epoch % N == 0:
pass
# modify data
self.epoch += 1