如何在批处理生成器中使用模型?

时间:2018-12-03 11:24:23

标签: python keras deep-learning

我想在批处理生成器中使用model.predict,有什么可能的方法实现呢?

似乎一种选择是在init和纪元末期加载模型:

class DataGenerator(keras.utils.Sequence):
    def __init__(self, model_name):
        # Load model

    # ...

    def on_epoch_end(self):
        # Load model

1 个答案:

答案 0 :(得分:1)

根据我的经验,在训练时预测另一个模型会带来错误。

您可能应该简单地将训练模型附加到生成器模型之后。

假设您有:

generator_model (the one you want to use inside the generator)    
training_model (the one you want to train)

然后

generatorInput = Input(shapeOfTheGeneratorInput)
generatorOutput = generator_model(generatorInput)
trainingOutput = training_model(generatorOutput)

entireModel = Model(generatorInput,trainingOutput)

在编译之前,请确保生成器模型的所有层均不可训练:

genModel = entireModel.layers[1]
for l in genModel.layers:
    l.trainable = False

entireModel.compile(optimizer=optimizer,loss=loss)

现在请定期使用发电机。


在生成器内部进行预测:

class DataGenerator(keras.utils.Sequence):

    def __init__(self, model_name, modelInputs, batchSize):
        self.genModel = load_model(model_name)
        self.inputs = modelInputs
        self.batchSize = batchSize


    def __len__(self):
        l,rem = divmod(len(self.inputs), self.batchSize)
        return (l + (1 if rem > 0 else 0))

    def __getitem__(self,i):

        items = self.inputs[i*self.batchSize:(i+1)*self.batchSize]
        items = doThingsWithItems(items)

        predItems = self.genModel.predict_on_batch(items)

        #the following is the only reason not to chain models
        predItems = doMoreThingsWithItems(predItems)

        #do something to get Y_train_items as well

        return predItems, y_train_items

如果确实发现了我提到的错误,则可以牺牲并行生成功能并执行一些手动循环:

for e in range(epochs):
    for i in range(batches):
        x,y = generator[i]
        model.train_on_batch(x,y)