I don't understand the dimensions of the weights returned by layer.get_weights()

Time: 2019-02-11 04:02:53

Tags: python-3.x keras conv-neural-network recurrent-neural-network keras-layer

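I ran a model on a cluster and saved the Keras model as an ".h5" file. Now I am reloading the model and want to visualize some of its kernels. To do that, I run the following command to get the configuration of one of the network's layers. However, I don't know what each of the weights is for.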

model_1.layers[40].get_config()

{'activation': 'tanh',
 'activity_regularizer': None,
 'bias_constraint': None,
 'bias_initializer': {'class_name': 'Zeros', 'config': {}},
 'bias_regularizer': None,
 'data_format': 'channels_last',
 'dilation_rate': (1, 1),
 'dropout': 0.0,
 'filters': 8,
 'go_backwards': False,
 'kernel_constraint': None,
 'kernel_initializer': {'class_name': 'VarianceScaling',
  'config': {'distribution': 'uniform',
   'mode': 'fan_avg',
   'scale': 1.0,
   'seed': None}},
 'kernel_regularizer': None,
 'kernel_size': (3, 3),
 'name': 'convlstm2d_3_6',
 'padding': 'same',
 'recurrent_activation': 'hard_sigmoid',
 'recurrent_constraint': None,
 'recurrent_dropout': 0.0,
 'recurrent_initializer': {'class_name': 'Orthogonal',
  'config': {'gain': 1.0, 'seed': None}},
 'recurrent_regularizer': None,
 'return_sequences': False,
 'return_state': False,
 'stateful': False,
 'strides': (1, 1),
 'trainable': True,
 'unit_forget_bias': True,
 'unroll': False,
 'use_bias': True}

And I am using the following commands to get the weights:

convlstm_3_6 = model_1.layers[40].get_weights()

print(len(convlstm_3_6))
print(len(convlstm_3_6[0]))
print(len(convlstm_3_6[1]))
print(len(convlstm_3_6[2]))

3
3
3
32

print(convlstm_3_6[1][0][0].shape)
print(convlstm_3_6[1][1][0].shape)
print(convlstm_3_6[1][2][0].shape)

(8, 32)
(8, 32)
(8, 32)

So I expected eight 3-by-3 matrices, but I don't know where the 32 comes from. The input to my model is 62 x 62. You can also see the model summary output below:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_5 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_6 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_7 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
convlstm2d_1_1 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_1[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_2 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_2[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_3 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_3[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_4 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_4[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_5 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_5[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_6 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_6[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_7 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_7[0][0]                    
__________________________________________________________________________________________________
BatchNormal_1_1 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_1[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_2 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_2[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_3 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_3[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_4 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_4[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_5 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_5[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_6 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_6[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_7 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_7[0][0]             
__________________________________________________________________________________________________
convlstm2d_2_1 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_1[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_2 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_2[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_3 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_3[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_4 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_4[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_5 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_5[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_6 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_6[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_7 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_7[0][0]            
__________________________________________________________________________________________________
BatchNormal_2_1 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_1[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_2 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_2[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_3 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_3[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_4 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_4[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_5 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_5[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_6 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_6[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_7 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_7[0][0]             
__________________________________________________________________________________________________
convlstm2d_3_1 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_1[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_2 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_2[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_3 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_3[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_4 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_4[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_5 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_5[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_6 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_6[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_7 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_7[0][0]            
__________________________________________________________________________________________________
BatchNormal_3_1 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_1[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_2 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_2[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_3 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_3[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_4 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_4[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_5 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_5[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_6 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_6[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_7 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_7[0][0]             
__________________________________________________________________________________________________
Output1 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_1[0][0]            
__________________________________________________________________________________________________
Output2 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_2[0][0]            
__________________________________________________________________________________________________
Output3 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_3[0][0]            
__________________________________________________________________________________________________
Output4 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_4[0][0]            
__________________________________________________________________________________________________
Output5 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_5[0][0]            
__________________________________________________________________________________________________
Output6 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_6[0][0]            
__________________________________________________________________________________________________
Output7 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_7[0][0]            
__________________________________________________________________________________________________
other_input (InputLayer)        (None, 62, 62, 18)   0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 62, 62, 25)   0           Output1[0][0]                    
                                                                 Output2[0][0]                    
                                                                 Output3[0][0]                    
                                                                 Output4[0][0]                    
                                                                 Output5[0][0]                    
                                                                 Output6[0][0]                    
                                                                 Output7[0][0]                    
                                                                 other_input[0][0]                
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 62, 62, 16)   40016       concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 62, 62, 8)    3208        conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 62, 62, 1)    201         conv2d_2[0][0]                   
==================================================================================================
Total params: 999,520
Trainable params: 998,512
Non-trainable params: 1,008
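
A possible reading of the 32 in the shapes printed earlier: Keras stores the weights of all four LSTM gates (input, forget, cell, output) of a ConvLSTM2D layer concatenated along the last axis, so with filters = 8 the last dimension would be 4 x 8 = 32. The sketch below is plain Python with hypothetical helper names (convlstm2d_weight_shapes and count_params are not Keras functions). It rebuilds the expected shapes of the three arrays returned by get_weights() and compares the totals with the Param # column. The value in_channels=32 for the third layer is an assumption inferred from the summary (BatchNormal_2_* has 128 parameters, i.e. 4 per channel); with n_filters_2 = 16 as in the posted code, the count would be 6,944 rather than 11,552.

```python
import math


def convlstm2d_weight_shapes(kernel_size, in_channels, filters):
    """Expected shapes of [kernel, recurrent_kernel, bias] for a ConvLSTM2D layer.

    Hypothetical helper: assumes channels_last and use_bias=True, as in the
    layer config shown in the question.
    """
    gates = 4  # input, forget, cell, output
    kh, kw = kernel_size
    kernel = (kh, kw, in_channels, gates * filters)        # applied to the layer input
    recurrent_kernel = (kh, kw, filters, gates * filters)  # applied to the hidden state
    bias = (gates * filters,)
    return [kernel, recurrent_kernel, bias]


def count_params(shapes):
    """Total number of scalars across all weight arrays."""
    return sum(math.prod(shape) for shape in shapes)


# in_channels=32 is an assumption inferred from the summary, not from the posted code.
shapes_3 = convlstm2d_weight_shapes((3, 3), in_channels=32, filters=8)
print(shapes_3)                # [(3, 3, 32, 32), (3, 3, 8, 32), (32,)]
print(count_params(shapes_3))  # 11552, matching convlstm2d_3_* in the summary
```

Under this reading, convlstm_3_6[1] is the recurrent kernel of shape (3, 3, 8, 32), which is why convlstm_3_6[1][0][0].shape prints (8, 32): 8 hidden-state channels by 4 x 8 gate filters. The same helper gives 50,816 for the first layer (12 input channels, 32 filters) and 73,856 for the second (32 input channels, 32 filters).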

This is how I define the ConvLSTM part:

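# Imports needed by the layers used below
from keras.layers import Input, ConvLSTM2D, BatchNormalization, Conv2D

img_size = 62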
Channels = 12


first_input = Input(shape=(None, img_size, img_size, Channels))
second_input = Input(shape=(None, img_size, img_size, Channels))
third_input = Input(shape=(None, img_size, img_size, Channels))
fourth_input = Input(shape=(None, img_size, img_size, Channels))
fifth_input = Input(shape=(None, img_size, img_size, Channels))
sixth_input = Input(shape=(None, img_size, img_size, Channels))
seventh_input = Input(shape=(None, img_size, img_size, Channels))



n_filters_1 = 32
n_filters_2 = 16
n_filters_3 = 8



def set_ConvLSTM_model(ConvLSTM_input, pattern):
    #First layer
    model_convlstm_1 = ConvLSTM2D(filters=n_filters_1, kernel_size=(3, 3), activation='sigmoid',
                                    padding='same', return_sequences=True, name='convlstm2d_1_' + str(pattern))(ConvLSTM_input)
    model_BatchNormal_1 = BatchNormalization(name='BatchNormal_1_' + str(pattern))(model_convlstm_1)

    #Second layer
    model_convlstm_2 = ConvLSTM2D(filters=n_filters_2, kernel_size=(3, 3),
                                    padding='same', return_sequences=True, name='convlstm2d_2_' + str(pattern))(model_BatchNormal_1)
    model_BatchNormal_2 = BatchNormalization(name='BatchNormal_2_' + str(pattern))(model_convlstm_2)

    #Third layer
    model_convlstm_3 = ConvLSTM2D(filters=n_filters_3, kernel_size=(3, 3),
                                    padding='same', return_sequences=False, name='convlstm2d_3_' + str(pattern))(model_BatchNormal_2)
    model_BatchNormal_3 = BatchNormalization(name='BatchNormal_3_' + str(pattern))(model_convlstm_3)

    #Last layer convolutional model
    model_conv_1 = Conv2D(filters=1, kernel_size=(3, 3),
                            activation='sigmoid',
                            padding='same', data_format='channels_last', name='Output' + str(pattern))(model_BatchNormal_3)
    return model_conv_1
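
For visualizing kernels, the last weight axis would first need to be split back into per-gate blocks. The sketch below assumes the gate order input, forget, cell, output along that axis, which is how the Keras ConvLSTM2D source appears to slice its kernels; GATE_ORDER and gate_slices are hypothetical names introduced here, not part of Keras.

```python
# Gate order along the last weight axis (assumed from the Keras ConvLSTM2D source).
GATE_ORDER = ("input", "forget", "cell", "output")


def gate_slices(filters):
    """Map each gate name to the slice of the last weight axis holding its filters."""
    return {
        name: slice(k * filters, (k + 1) * filters)
        for k, name in enumerate(GATE_ORDER)
    }


slices = gate_slices(8)  # filters=8 for the convlstm2d_3_* layers
for name, s in slices.items():
    print(name, s.start, s.stop)
# input 0 8
# forget 8 16
# cell 16 24
# output 24 32
```

With the NumPy arrays returned by get_weights(), an expression such as convlstm_3_6[1][:, :, :, slices["forget"]] would then have shape (3, 3, 8, 8), and fixing one input channel c and one filter k with [:, :, c, k] leaves a single 3-by-3 matrix that can be plotted.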

0 answers:

No answers