我想在批处理生成器中使用model.predict
,有什么可能的方法实现呢?
似乎一种选择是在init和纪元末期加载模型:
class DataGenerator(keras.utils.Sequence):
def __init__(self, model_name):
# Load model
# ...
def on_epoch_end(self):
# Load model
答案 0 :(得分:1)
根据我的经验,在训练时预测另一个模型会带来错误。
您可能应该简单地将训练模型附加到生成器模型之后。
假设您有:
generator_model (the one you want to use inside the generator)
training_model (the one you want to train)
然后
generatorInput = Input(shapeOfTheGeneratorInput)
generatorOutput = generator_model(generatorInput)
trainingOutput = training_model(generatorOutput)
entireModel = Model(generatorInput,trainingOutput)
在编译之前,请确保生成器模型的所有层均不可训练:
genModel = entireModel.layers[1]
for l in genModel.layers:
l.trainable = False
entireModel.compile(optimizer=optimizer,loss=loss)
现在请定期使用发电机。
class DataGenerator(keras.utils.Sequence):
def __init__(self, model_name, modelInputs, batchSize):
self.genModel = load_model(model_name)
self.inputs = modelInputs
self.batchSize = batchSize
def __len__(self):
l,rem = divmod(len(self.inputs), self.batchSize)
return (l + (1 if rem > 0 else 0))
def __getitem__(self,i):
items = self.inputs[i*self.batchSize:(i+1)*self.batchSize]
items = doThingsWithItems(items)
predItems = self.genModel.predict_on_batch(items)
#the following is the only reason not to chain models
predItems = doMoreThingsWithItems(predItems)
#do something to get Y_train_items as well
return predItems, y_train_items
如果确实发现了我提到的错误,则可以牺牲并行生成功能并执行一些手动循环:
for e in range(epochs):
for i in range(batches):
x,y = generator[i]
model.train_on_batch(x,y)