迁移学习可训练模型在保存时引发错误

时间:2021-02-27 02:44:57

标签: python tensorflow keras deep-learning oserror

我已经下载了强文本一个预训练模型,我正在尝试迁移学习它。 因此我正在加载保存为“xray_model.h5”文件的模型,并将其设置为不可训练:

model = tf.keras.models.load_model('xray_model.h5')
model.trainable = False

稍后我使用开始层和结束层并在其上构建我的添加:

base_input = model.layers[0].input
base_output = model.get_layer(name="flatten").output

base_output = build_model()(base_output)

new_model = keras.Model(inputs=base_input, outputs=base_output)

因为我想训练我的层(在一些游戏之后,我意识到我可能也需要训练旧层)我想将模型设置为可训练:

for i in range(len(new_model.layers)):
    new_model._layers[i].trainable = True

但是,当我开始训练它时,回调:

METRICS = ['accuracy',
           tf.keras.metrics.Precision(name='precision'),
           tf.keras.metrics.Recall(name='recall'),
           lr_metric]

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=0.00001, verbose=1)

save_callback = tf.keras.callbacks.ModelCheckpoint("new_xray_model.h5",
                                                   save_best_only=True,
                                                   monitor='accuracy')
history = new_model.fit(train_generator,
                        verbose=1,
                        steps_per_epoch=BATCH_SIZE,
                        epochs=EPOCHS,
                        validation_data=test_generator,
                        callbacks=[save_callback, reduce_lr])

我收到下一个错误:

File "C:\Users\jm10o\AppData\Local\Programs\Python\Python38\lib\site-packages\h5py\_hl\group.py", line 373, in __setitem__
    h5o.link(obj.id, self.id, name, lcpl=lcpl, lapl=self._lapl)
  File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py\h5o.pyx", line 202, in h5py.h5o.link
OSError: Unable to create link (name already exists)

Process finished with exit code 1

我注意到只有在我尝试进一步训练我加载的模型时才会发生这种情况。 我找不到任何解决方案。

1 个答案:

答案 0 :(得分:0)

问题来自 Model_checkpoint 回调。对于每个时期,您都使用相同的名称保存模型。

使用以下格式

ModelCheckpoint('your_model_name{epoch:0d}.h5',
                    monitor='accuracy')