我在集群上运行了一个模型,并将Keras模型保存为“ .h5”文件。 现在,我正在重新加载模型,并希望可视化该模型中的某些内核。因此,我运行以下命令来获取网络各层的配置。但是我不知道每个权重是什么目的。
model_1.layers[40].get_config()
{'activation': 'tanh',
'activity_regularizer': None,
'bias_constraint': None,
'bias_initializer': {'class_name': 'Zeros', 'config': {}},
'bias_regularizer': None,
'data_format': 'channels_last',
'dilation_rate': (1, 1),
'dropout': 0.0,
'filters': 8,
'go_backwards': False,
'kernel_constraint': None,
'kernel_initializer': {'class_name': 'VarianceScaling',
'config': {'distribution': 'uniform',
'mode': 'fan_avg',
'scale': 1.0,
'seed': None}},
'kernel_regularizer': None,
'kernel_size': (3, 3),
'name': 'convlstm2d_3_6',
'padding': 'same',
'recurrent_activation': 'hard_sigmoid',
'recurrent_constraint': None,
'recurrent_dropout': 0.0,
'recurrent_initializer': {'class_name': 'Orthogonal',
'config': {'gain': 1.0, 'seed': None}},
'recurrent_regularizer': None,
'return_sequences': False,
'return_state': False,
'stateful': False,
'strides': (1, 1),
'trainable': True,
'unit_forget_bias': True,
'unroll': False,
'use_bias': True}
并且我正在使用以下命令获取权重:
convlstm_3_6 = model_1.layers[40].get_weights()
print(len(convlstm_3_6))
print(len(convlstm_3_6[0]))
print(len(convlstm_3_6[1]))
print(len(convlstm_3_6[2]))
3
3
3
32
print(convlstm_3_6[1][0][0].shape)
print(convlstm_3_6[1][1][0].shape)
print(convlstm_3_6[1][2][0].shape)
(8, 32)
(8, 32)
(8, 32)
因此,我原本期望得到 8 个 3×3 的矩阵,但我不知道 32 是从哪里来的。我给模型的输入是 62×62。您还可以在下面看到模型架构的摘要输出:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, None, 62, 62, 0
__________________________________________________________________________________________________
input_2 (InputLayer) (None, None, 62, 62, 0
__________________________________________________________________________________________________
input_3 (InputLayer) (None, None, 62, 62, 0
__________________________________________________________________________________________________
input_4 (InputLayer) (None, None, 62, 62, 0
__________________________________________________________________________________________________
input_5 (InputLayer) (None, None, 62, 62, 0
__________________________________________________________________________________________________
input_6 (InputLayer) (None, None, 62, 62, 0
__________________________________________________________________________________________________
input_7 (InputLayer) (None, None, 62, 62, 0
__________________________________________________________________________________________________
convlstm2d_1_1 (ConvLSTM2D) (None, None, 62, 62, 50816 input_1[0][0]
__________________________________________________________________________________________________
convlstm2d_1_2 (ConvLSTM2D) (None, None, 62, 62, 50816 input_2[0][0]
__________________________________________________________________________________________________
convlstm2d_1_3 (ConvLSTM2D) (None, None, 62, 62, 50816 input_3[0][0]
__________________________________________________________________________________________________
convlstm2d_1_4 (ConvLSTM2D) (None, None, 62, 62, 50816 input_4[0][0]
__________________________________________________________________________________________________
convlstm2d_1_5 (ConvLSTM2D) (None, None, 62, 62, 50816 input_5[0][0]
__________________________________________________________________________________________________
convlstm2d_1_6 (ConvLSTM2D) (None, None, 62, 62, 50816 input_6[0][0]
__________________________________________________________________________________________________
convlstm2d_1_7 (ConvLSTM2D) (None, None, 62, 62, 50816 input_7[0][0]
__________________________________________________________________________________________________
BatchNormal_1_1 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_1_1[0][0]
__________________________________________________________________________________________________
BatchNormal_1_2 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_1_2[0][0]
__________________________________________________________________________________________________
BatchNormal_1_3 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_1_3[0][0]
__________________________________________________________________________________________________
BatchNormal_1_4 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_1_4[0][0]
__________________________________________________________________________________________________
BatchNormal_1_5 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_1_5[0][0]
__________________________________________________________________________________________________
BatchNormal_1_6 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_1_6[0][0]
__________________________________________________________________________________________________
BatchNormal_1_7 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_1_7[0][0]
__________________________________________________________________________________________________
convlstm2d_2_1 (ConvLSTM2D) (None, None, 62, 62, 73856 BatchNormal_1_1[0][0]
__________________________________________________________________________________________________
convlstm2d_2_2 (ConvLSTM2D) (None, None, 62, 62, 73856 BatchNormal_1_2[0][0]
__________________________________________________________________________________________________
convlstm2d_2_3 (ConvLSTM2D) (None, None, 62, 62, 73856 BatchNormal_1_3[0][0]
__________________________________________________________________________________________________
convlstm2d_2_4 (ConvLSTM2D) (None, None, 62, 62, 73856 BatchNormal_1_4[0][0]
__________________________________________________________________________________________________
convlstm2d_2_5 (ConvLSTM2D) (None, None, 62, 62, 73856 BatchNormal_1_5[0][0]
__________________________________________________________________________________________________
convlstm2d_2_6 (ConvLSTM2D) (None, None, 62, 62, 73856 BatchNormal_1_6[0][0]
__________________________________________________________________________________________________
convlstm2d_2_7 (ConvLSTM2D) (None, None, 62, 62, 73856 BatchNormal_1_7[0][0]
__________________________________________________________________________________________________
BatchNormal_2_1 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_2_1[0][0]
__________________________________________________________________________________________________
BatchNormal_2_2 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_2_2[0][0]
__________________________________________________________________________________________________
BatchNormal_2_3 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_2_3[0][0]
__________________________________________________________________________________________________
BatchNormal_2_4 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_2_4[0][0]
__________________________________________________________________________________________________
BatchNormal_2_5 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_2_5[0][0]
__________________________________________________________________________________________________
BatchNormal_2_6 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_2_6[0][0]
__________________________________________________________________________________________________
BatchNormal_2_7 (BatchNormaliza (None, None, 62, 62, 128 convlstm2d_2_7[0][0]
__________________________________________________________________________________________________
convlstm2d_3_1 (ConvLSTM2D) (None, 62, 62, 8) 11552 BatchNormal_2_1[0][0]
__________________________________________________________________________________________________
convlstm2d_3_2 (ConvLSTM2D) (None, 62, 62, 8) 11552 BatchNormal_2_2[0][0]
__________________________________________________________________________________________________
convlstm2d_3_3 (ConvLSTM2D) (None, 62, 62, 8) 11552 BatchNormal_2_3[0][0]
__________________________________________________________________________________________________
convlstm2d_3_4 (ConvLSTM2D) (None, 62, 62, 8) 11552 BatchNormal_2_4[0][0]
__________________________________________________________________________________________________
convlstm2d_3_5 (ConvLSTM2D) (None, 62, 62, 8) 11552 BatchNormal_2_5[0][0]
__________________________________________________________________________________________________
convlstm2d_3_6 (ConvLSTM2D) (None, 62, 62, 8) 11552 BatchNormal_2_6[0][0]
__________________________________________________________________________________________________
convlstm2d_3_7 (ConvLSTM2D) (None, 62, 62, 8) 11552 BatchNormal_2_7[0][0]
__________________________________________________________________________________________________
BatchNormal_3_1 (BatchNormaliza (None, 62, 62, 8) 32 convlstm2d_3_1[0][0]
__________________________________________________________________________________________________
BatchNormal_3_2 (BatchNormaliza (None, 62, 62, 8) 32 convlstm2d_3_2[0][0]
__________________________________________________________________________________________________
BatchNormal_3_3 (BatchNormaliza (None, 62, 62, 8) 32 convlstm2d_3_3[0][0]
__________________________________________________________________________________________________
BatchNormal_3_4 (BatchNormaliza (None, 62, 62, 8) 32 convlstm2d_3_4[0][0]
__________________________________________________________________________________________________
BatchNormal_3_5 (BatchNormaliza (None, 62, 62, 8) 32 convlstm2d_3_5[0][0]
__________________________________________________________________________________________________
BatchNormal_3_6 (BatchNormaliza (None, 62, 62, 8) 32 convlstm2d_3_6[0][0]
__________________________________________________________________________________________________
BatchNormal_3_7 (BatchNormaliza (None, 62, 62, 8) 32 convlstm2d_3_7[0][0]
__________________________________________________________________________________________________
Output1 (Conv2D) (None, 62, 62, 1) 73 BatchNormal_3_1[0][0]
__________________________________________________________________________________________________
Output2 (Conv2D) (None, 62, 62, 1) 73 BatchNormal_3_2[0][0]
__________________________________________________________________________________________________
Output3 (Conv2D) (None, 62, 62, 1) 73 BatchNormal_3_3[0][0]
__________________________________________________________________________________________________
Output4 (Conv2D) (None, 62, 62, 1) 73 BatchNormal_3_4[0][0]
__________________________________________________________________________________________________
Output5 (Conv2D) (None, 62, 62, 1) 73 BatchNormal_3_5[0][0]
__________________________________________________________________________________________________
Output6 (Conv2D) (None, 62, 62, 1) 73 BatchNormal_3_6[0][0]
__________________________________________________________________________________________________
Output7 (Conv2D) (None, 62, 62, 1) 73 BatchNormal_3_7[0][0]
__________________________________________________________________________________________________
other_input (InputLayer) (None, 62, 62, 18) 0
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 62, 62, 25) 0 Output1[0][0]
Output2[0][0]
Output3[0][0]
Output4[0][0]
Output5[0][0]
Output6[0][0]
Output7[0][0]
other_input[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 62, 62, 16) 40016 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 62, 62, 8) 3208 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 62, 62, 1) 201 conv2d_2[0][0]
==================================================================================================
Total params: 999,520
Trainable params: 998,512
Non-trainable params: 1,008
这就是我定义ConvLSTM部分的方式:
# Spatial size of each input frame (frames are img_size x img_size pixels).
img_size = 62
# Channel depth of each input frame.
Channels = 12
# Seven parallel input streams, one per "pattern"/branch.
# Shape is (time, height, width, channels); the leading None is a
# variable-length time dimension consumed by the ConvLSTM2D layers.
first_input = Input(shape=(None, img_size, img_size, Channels))
second_input = Input(shape=(None, img_size, img_size, Channels))
third_input = Input(shape=(None, img_size, img_size, Channels))
fourth_input = Input(shape=(None, img_size, img_size, Channels))
fifth_input = Input(shape=(None, img_size, img_size, Channels))
sixth_input = Input(shape=(None, img_size, img_size, Channels))
seventh_input = Input(shape=(None, img_size, img_size, Channels))
# Filter counts for the three stacked ConvLSTM2D stages (32 -> 16 -> 8),
# matching the model summary shown above.
n_filters_1 = 32
n_filters_2 = 16
n_filters_3 = 8
def set_ConvLSTM_model(ConvLSTM_input, pattern):
    """Build one ConvLSTM branch of the model.

    Stacks three ConvLSTM2D + BatchNormalization stages and finishes with a
    single-filter Conv2D head, reproducing the `convlstm2d_*_<pattern>` /
    `BatchNormal_*_<pattern>` / `Output<pattern>` columns of the summary above.

    Args:
        ConvLSTM_input: 5D tensor (batch, time, height, width, channels),
            e.g. one of the `Input(shape=(None, 62, 62, 12))` streams.
        pattern: branch index appended to every layer name so the seven
            parallel branches get unique layer names.

    Returns:
        4D tensor (batch, height, width, 1): the branch's output map.
    """
    # First stage: n_filters_1 (=32) filters, sequence-to-sequence
    # (return_sequences=True keeps the time dimension).
    model_convlstm_1 = ConvLSTM2D(filters=n_filters_1, kernel_size=(3, 3),
                                  activation='sigmoid',
                                  padding='same', return_sequences=True,
                                  name='convlstm2d_1_' + str(pattern))(ConvLSTM_input)
    model_BatchNormal_1 = BatchNormalization(name='BatchNormal_1_' + str(pattern))(model_convlstm_1)
    # Second stage: n_filters_2 (=16) filters, still sequence-to-sequence.
    model_convlstm_2 = ConvLSTM2D(filters=n_filters_2, kernel_size=(3, 3),
                                  padding='same', return_sequences=True,
                                  name='convlstm2d_2_' + str(pattern))(model_BatchNormal_1)
    model_BatchNormal_2 = BatchNormalization(name='BatchNormal_2_' + str(pattern))(model_convlstm_2)
    # Third stage: n_filters_3 (=8) filters; return_sequences=False collapses
    # the time axis, so the output drops to (batch, 62, 62, 8).
    # NOTE: a ConvLSTM2D layer stores the kernels of its 4 LSTM gates
    # (input/forget/cell/output) stacked along the last axis, so
    # get_weights()[0] has shape (3, 3, in_channels, 4 * filters) and
    # get_weights()[1] (the recurrent kernel) has shape
    # (3, 3, filters, 4 * filters). With filters=8 that last axis is 32 --
    # which is where the (8, 32) slices in the question come from.
    model_convlstm_3 = ConvLSTM2D(filters=n_filters_3, kernel_size=(3, 3),
                                  padding='same', return_sequences=False,
                                  name='convlstm2d_3_' + str(pattern))(model_BatchNormal_2)
    model_BatchNormal_3 = BatchNormalization(name='BatchNormal_3_' + str(pattern))(model_convlstm_3)
    # Output head: one 3x3 convolution producing a 1-channel 62x62 map.
    model_conv_1 = Conv2D(filters=1, kernel_size=(3, 3),
                          activation='sigmoid',
                          padding='same', data_format='channels_last',
                          name='Output' + str(pattern))(model_BatchNormal_3)
    return model_conv_1