Unable to save a model made with the Keras functional API

Date: 2018-10-06 19:33:45

Tags: python python-3.x tensorflow keras

Specifically, I am trying to save the entire MaskRCNN model implemented here: https://github.com/matterport/Mask_RCNN

On line 2343 of https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/model.py, I basically changed save_weights_only from True to False so that the entire model is saved:

keras.callbacks.ModelCheckpoint(self.checkpoint_path, verbose=0, save_weights_only=False),

The stack trace of the error is below:

  File "./samples/coco/coco.py", line 509, in <module>
    augmentation=augmentation)
  File "/mask_rcnn_root/Mask_RCNN/mrcnn/model.py", line 2374, in train
    use_multiprocessing=True,
  File "/usr/local/lib/python3.5/dist-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1415, in fit_generator
    initial_epoch=initial_epoch)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training_generator.py", line 247, in fit_generator
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/usr/local/lib/python3.5/dist-packages/keras/callbacks.py", line 77, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/usr/local/lib/python3.5/dist-packages/keras/callbacks.py", line 455, in on_epoch_end
    self.model.save(filepath, overwrite=True)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/network.py", line 1085, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/saving.py", line 116, in save_model
    'config': model.get_config()
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/network.py", line 926, in get_config
    return copy.deepcopy(config)
  File "/usr/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.5/copy.py", line 218, in _deepcopy_list
    y.append(deepcopy(a, memo))
  File "/usr/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.5/copy.py", line 223, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "/usr/lib/python3.5/copy.py", line 223, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "/usr/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.5/copy.py", line 223, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "/usr/lib/python3.5/copy.py", line 223, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "/usr/lib/python3.5/copy.py", line 182, in deepcopy
    y = _reconstruct(x, rv, 1, memo)
  File "/usr/lib/python3.5/copy.py", line 297, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.5/copy.py", line 182, in deepcopy
    y = _reconstruct(x, rv, 1, memo)
  File "/usr/lib/python3.5/copy.py", line 297, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.5/copy.py", line 174, in deepcopy
    rv = reductor(4)
TypeError: can't pickle SwigPyObject objects

Thanks!

1 Answer:

Answer 0 (score: 0)

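Basically, the cause is that improper use of Lambda layers in Keras breaks model saving. The weights can still be saved with model.save_weights("my_model.h5"), but trying to save the entire model or extract its graph structure crashes. So in my case, all of the following fail: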

model.save('my_model.h5')
json_string = model.to_json()
yaml_string = model.to_yaml()
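One plausible reading of the trace, consistent with its frames (get_config, then copy.deepcopy, then reductor(4)): Keras 2.x serializes a Lambda layer's Python function together with the objects captured in its closure, and get_config() then deep-copies the whole config. If a lambda captures a graph tensor from the enclosing scope, deepcopy walks into that tensor, then its op, and finally hits an unpicklable SWIG pointer. The stdlib-only sketch below reproduces that shape without TensorFlow. FakeSwigPointer, FakeOperation, FakeTensor, func_dump_like, build_config, make_capturing_fn and as_inputs are all hypothetical stand-ins, not Keras or TensorFlow APIs; the nesting (dict, list, dict, dict, tuple, tuple, object, object, pointer) is chosen to mirror the deepcopy frames above.

```python
import copy


class FakeSwigPointer:
    """Hypothetical stand-in for a SWIG-wrapped C pointer (SwigPyObject)."""

    def __reduce_ex__(self, protocol):
        # copy.deepcopy falls back to __reduce_ex__(4), the same call that
        # fails at the bottom of the trace; refuse it like the real object.
        raise TypeError("can't pickle SwigPyObject objects")


class FakeOperation:
    """Hypothetical graph op that owns a C-level handle."""

    def __init__(self):
        self._c_op = FakeSwigPointer()


class FakeTensor:
    """Hypothetical graph tensor that references the op producing it."""

    def __init__(self):
        self._op = FakeOperation()


def func_dump_like(func):
    """Approximates Keras 2.x Lambda serialization: (code, defaults, closure).

    The closure entry holds the actual objects the function captured.
    """
    closure = (tuple(cell.cell_contents for cell in func.__closure__)
               if func.__closure__ else None)
    return func.__code__.co_code, func.__defaults__, closure


def build_config(fn):
    """Hypothetical model config shaped like a functional model's get_config()."""
    return {"layers": [{"class_name": "Lambda",
                        "config": {"function": func_dump_like(fn)}}]}


def deepcopy_error(config):
    """Return the TypeError message deepcopy raises, or None on success."""
    try:
        copy.deepcopy(config)
    except TypeError as err:
        return str(err)
    return None


def make_capturing_fn(image):
    # Problematic pattern: `image` is captured from the enclosing scope,
    # so the tensor object ends up in the serialized closure.
    return lambda boxes: (boxes, image)


def as_inputs(xs):
    # Safer pattern: every tensor arrives through the inputs; no closure.
    return xs[0], xs[1]


print(deepcopy_error(build_config(make_capturing_fn(FakeTensor()))))
# can't pickle SwigPyObject objects
print(deepcopy_error(build_config(as_inputs)))
# None
```

The point of the contrast is that a function receiving every tensor through its inputs has no closure to serialize. In Keras terms that would suggest passing the extra tensor into the Lambda as part of a list input instead of referencing it from the surrounding code; that is an untested suggestion, not something stated in the answer. Failing that, as the answer notes, save_weights still works.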

More details here: https://github.com/keras-team/keras/issues/11020#issuecomment-427638145