尽管我已经编译了Keras顺序模型,但是我遇到了RuntimeError:您必须先编译模型再使用它

时间:2019-02-04 11:51:20

标签: python keras runtime-error

我正在尝试使用CNN-RNN模型对图像标题生成器进行编码。对于RNN部分,我创建了自己的类CaptionGenerator,其方法create_model(如下所示)之一创建了两个顺序模型-图像模型和语言模型,并将它们连接起来。之后,我编译模型。

def create_model(self, ret_model = False):

    image_model = Sequential()
    image_model.add(Dense(EMBEDDING_DIM, input_dim = 4096, activation='relu'))
    image_model.add(RepeatVector(self.max_cap_len))

    lang_model = Sequential()
    lang_model.add(Embedding(self.vocab_size, 256, input_length=self.max_cap_len))
    lang_model.add(LSTM(256,return_sequences=True))
    lang_model.add(TimeDistributed(Dense(EMBEDDING_DIM)))

    model = Sequential()
    model.add(keras.layers.Concatenate([image_model, lang_model]))
    model.add(LSTM(1000,return_sequences=False))
    model.add(Dense(self.vocab_size))
    model.add(Activation('softmax'))

    if(ret_model==True):
        return model

    model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

    return model

在另一个文件中,我创建此类的对象,并使用该对象创建上述模型的对象。

def train_model(weight = None, batch_size=32, epochs = 10):
    cg = caption_generator.CaptionGenerator()
    model = cg.create_model()

    if weight != None:
        model.load_weights(weight)

    counter = 0
    file_name = 'weights-improvement-{epoch:02d}.hdf5'
    checkpoint = ModelCheckpoint(file_name, monitor='loss', verbose=1, 
                 save_best_only=True, mode='min')
    callbacks_list = [checkpoint]
    print(" Fitting the data ") 
    model.fit_generator(cg.data_generator(batch_size=batch_size), 
                        steps_per_epoch=cg.total_samples/batch_size, 
                        epochs=epochs, verbose=2, 
                        callbacks=callbacks_list)
    try:
        model.save('Models/WholeModel.h5', overwrite=True)
        model.save_weights('Models/Weights.h5',overwrite=True)
    except:
        print ( "Error in saving model." )
    print ( "Training complete...\n" )

但是当我调用function model.fit_generator()时,尽管模型是用create_model类的CaptionGenerator方法编译的,但是却遇到以下错误:

RuntimeError: You must compile your model before using it.

Error Stack

Model Architecture

0 个答案:

没有答案