Keras 2.2: Cannot load a pre-built model with imagenet weights

Date: 2019-01-27 05:40:40

Tags: python tensorflow keras

I have code that used to work with an older version of Keras, but with Keras 2.2 I get an error about loading a model with too few layers into a larger one:

import keras
from keras.applications import Xception
from keras.layers import (Input, Conv2D, BatchNormalization, LeakyReLU,
                          MaxPooling2D, AveragePooling2D, Add)

kernel_size = (3, 3)  
pool_size = (2, 2)  
nfilters = 3
inputs = Input(shape=(331, 331, 1))
x = inputs
x = Conv2D(nfilters, kernel_size, strides=(1,1), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.1)(x)
x = MaxPooling2D(pool_size=pool_size)(x)
x = Add()([x, AveragePooling2D(pool_size=pool_size)(inputs)])  # residual skip connection on shrunk image
base_model = Xception(weights='imagenet', include_top=False, input_tensor=x)

The error I get, raised by the Xception call, is:

ValueError: You are trying to load a weight file containing 80 layers into a model with 82 layers.

Here is a link to a Google Colab notebook that reproduces this.

The problem only appears when loading the imagenet weights; if I set weights=None, everything works.

This kind of error can be avoided by passing by_name=True in a load_weights() call, but pre-built models like Xception do not accept a by_name keyword.
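A likely explanation, stated as an assumption rather than something verified here: when input_tensor is the output of other layers, Keras 2.2's Xception builds its Model starting from the original Input, so the custom Conv2D and BatchNormalization layers end up inside the returned model. Loading weights by topological order only counts layers that own weights, which would give 82 (the 80 Xception layers plus those two) against the 80 in the weight file. The toy simulation below imitates that count check and by-name matching in plain Python; it does not call Keras, and the function and layer names are hypothetical.

```python
# Toy imitation of loading weights by topological order vs. by name.
# Not Keras code: each "layer" is a (name, has_weights) pair.

def weighted_layer_names(layers):
    """Keep only layers that own weights, as order-based loading does."""
    return [name for name, has_weights in layers if has_weights]

def load_by_order(file_layer_names, model_layers):
    """Pair file entries with model layers purely by position (hypothetical)."""
    model_weighted = weighted_layer_names(model_layers)
    if len(file_layer_names) != len(model_weighted):
        raise ValueError(
            "You are trying to load a weight file containing "
            f"{len(file_layer_names)} layers into a model with "
            f"{len(model_weighted)} layers."
        )
    return list(zip(file_layer_names, model_weighted))

def load_by_name(file_layer_names, model_layers):
    """Match file entries to model layers by name, ignoring extras (hypothetical)."""
    model_weighted = set(weighted_layer_names(model_layers))
    return [name for name in file_layer_names if name in model_weighted]

# Stand-in for the weight file: 80 weighted Xception layers.
xception_layers = [(f"xception_layer_{i}", True) for i in range(80)]
file_layer_names = weighted_layer_names(xception_layers)

# The custom preprocessing layers that get pulled into the same Model.
preprocessing = [
    ("input_1", False),
    ("conv2d_1", True),
    ("batch_normalization_1", True),
    ("leaky_re_lu_1", False),
    ("max_pooling2d_1", False),
    ("average_pooling2d_1", False),
    ("add_1", False),
]
model_layers = preprocessing + xception_layers

try:
    load_by_order(file_layer_names, model_layers)
except ValueError as err:
    print(err)  # ... containing 80 layers into a model with 82 layers.

print(len(load_by_name(file_layer_names, model_layers)))  # 80
```

Matching by name skips the two extra weighted layers, which is why by_name=True would sidestep the error if Xception exposed it.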

Can anyone explain how to make my code work again under Keras 2.2?

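I suppose I could define Xception twice, once on its own with imagenet weights and once inside my full model with weights=None, and then copy the weights from the former to the latter... but I would rather not have to do that if possible.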
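The weight-copying workaround could be sketched roughly as follows. This is a minimal sketch under stated assumptions: copy_tail_weights is a hypothetical helper that relies only on a model's .layers list and the Keras layer methods get_weights() and set_weights(), and it assumes the Xception layers are the last weighted layers of the full model, in the same order as in the standalone model. If that assumption fails, a misalignment will usually surface as a ValueError from set_weights because the shapes differ.

```python
def copy_tail_weights(source_model, target_model):
    """Copy each weighted layer of source_model onto the trailing weighted
    layers of target_model, pairing them in order (hypothetical helper).

    Assumes the layers shared with source_model are the LAST weighted layers
    of target_model and appear in the same order. Works with any objects
    exposing .layers, get_weights() and set_weights().
    """
    source_weighted = [layer for layer in source_model.layers if layer.get_weights()]
    target_weighted = [layer for layer in target_model.layers if layer.get_weights()]
    offset = len(target_weighted) - len(source_weighted)
    if offset < 0:
        raise ValueError("target model has fewer weighted layers than source model")
    # Skip the target's leading weighted layers (e.g. the custom Conv2D and
    # BatchNormalization), then copy pairwise.
    for src, dst in zip(source_weighted, target_weighted[offset:]):
        dst.set_weights(src.get_weights())
    return len(source_weighted)
```

Under those assumptions, usage in Keras 2.2 would be to build full_model = Xception(weights=None, include_top=False, input_tensor=x) and then call copy_tail_weights(Xception(weights='imagenet', include_top=False), full_model).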

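("Why put these layers before Xception?" Because I am shrinking a larger image down to the size Xception needs for its imagenet weights, and converting the grayscale image into a 3-channel image.)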
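The shape arithmetic behind that preprocessing can be checked by hand. This sketch assumes the pooling layers use Keras's default 'valid' padding and imitates (hypothetically, not Keras code) the size-1 broadcasting that the Add merge applies to the 1-channel and 3-channel branches; the resulting shapes agree with the model summary in the answer.

```python
def pool_output_size(size, pool, stride=None):
    """Output length of a 'valid'-padded pooling window along one axis."""
    stride = stride or pool
    return (size - pool) // stride + 1

def broadcast_channels(a, b):
    """Element-wise merge shape when one side has a size-1 axis (imitation)."""
    out = []
    for i, j in zip(a, b):
        if i != j and 1 not in (i, j):
            raise ValueError(f"incompatible shapes {a} and {b}")
        out.append(max(i, j))
    return tuple(out)

side = 331
# Conv2D with padding='same' and stride 1 keeps 331x331 and gives 3 channels;
# MaxPooling2D(2, 2) then halves each spatial axis.
pooled = (pool_output_size(side, 2), pool_output_size(side, 2), 3)
# AveragePooling2D(2, 2) applied directly to the 1-channel input.
shrunk_input = (pool_output_size(side, 2), pool_output_size(side, 2), 1)
print(pooled, shrunk_input, broadcast_channels(pooled, shrunk_input))
# (165, 165, 3) (165, 165, 1) (165, 165, 3)
```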

1 Answer:

Answer 0 (score: 2)

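I'm not entirely sure how to explain the error, but you can make it work by treating the Xception model as a layer: call it on the output of the preceding layers, then wrap the whole stack in a Model instance. I verified the following in your Colab notebook.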

import keras
from keras.applications import Xception
from keras.layers import (Input, Conv2D, BatchNormalization, LeakyReLU,
                          MaxPooling2D, AveragePooling2D, Add)

kernel_size = (3, 3)  
pool_size = (2, 2)  
nfilters = 3
inputs = Input(shape=(331, 331, 1))
x = inputs
x = Conv2D(nfilters, kernel_size, strides=(1,1), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.1)(x)
x = MaxPooling2D(pool_size=pool_size)(x)
x = Add()([x, AveragePooling2D(pool_size=pool_size)(inputs)])  # residual skip connection on shrunk image

# Xception architecture is just another layer
base_model = Xception(weights='imagenet', include_top=False)
output = base_model(x)
# Wrap everything into a model
combined_model = keras.models.Model(inputs=inputs, outputs=output)

This gives you a model that looks like this:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            (None, 331, 331, 1)  0                                            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 331, 331, 3)  27          input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 331, 331, 3)  12          conv2d_6[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 331, 331, 3)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 165, 165, 3)  0           leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, 165, 165, 1)  0           input_2[0][0]                    
__________________________________________________________________________________________________
add_14 (Add)                    (None, 165, 165, 3)  0           max_pooling2d_2[0][0]            
                                                                 average_pooling2d_2[0][0]        
__________________________________________________________________________________________________
xception (Model)                multiple             20861480    add_14[0][0]                     
==================================================================================================
Total params: 20,861,519
Trainable params: 20,806,985
Non-trainable params: 54,534
__________________________________________________________________________________________________
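The parameter counts in this summary can be reproduced with a little arithmetic: a Conv2D without bias has kernel_height × kernel_width × input_channels × filters weights, and BatchNormalization keeps four values per channel, of which gamma and beta are trainable while the moving mean and variance are not. The Xception total and the overall non-trainable count below are copied from the summary; the helper names are illustrative only.

```python
def conv2d_params(kernel_size, in_channels, filters, use_bias):
    """Number of weights in a Conv2D layer."""
    kh, kw = kernel_size
    return kh * kw * in_channels * filters + (filters if use_bias else 0)

def batchnorm_params(channels):
    """gamma and beta are trainable; moving_mean and moving_variance are not."""
    return {"trainable": 2 * channels, "non_trainable": 2 * channels}

conv = conv2d_params((3, 3), in_channels=1, filters=3, use_bias=False)
bn = batchnorm_params(3)
xception_total = 20861480      # "xception (Model)" row of the summary
total_non_trainable = 54534    # "Non-trainable params" line of the summary

total = conv + bn["trainable"] + bn["non_trainable"] + xception_total
trainable = total - total_non_trainable
xception_non_trainable = total_non_trainable - bn["non_trainable"]

print(conv, bn["trainable"] + bn["non_trainable"])  # 27 12
print(total, trainable)                             # 20861519 20806985
print(xception_non_trainable)                       # 54528
```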