How to apply a Masking layer to a sequential CNN model in Keras?

Asked: 2018-12-30 12:33:11

Tags: keras conv-neural-network lstm mask masking

I'm having trouble applying a masking layer to the CNN part of an RNN/LSTM model.

My data are not raw images; I converted them into arrays of shape (16, 34, 4) (channels_first). The data are sequential, with a maximum length of 22 steps, so I fixed the timestep dimension at 22. Since a sequence can be shorter than 22 steps, I pad the remaining steps with np.zeros. This zero padding makes up roughly half of the whole dataset, so training on that much useless padding does not give good results. I therefore want to add a mask to cancel out the zero-padded steps.
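For reference, the padding step looks roughly like this (a minimal sketch; pad_frames is just an illustration of what I described above):

import numpy as np

def pad_frames(frames, max_steps=22):
    # frames: array of shape (n_steps, 16, 34, 4) with n_steps <= 22
    padded = np.zeros((max_steps,) + frames.shape[1:], dtype=frames.dtype)
    padded[:len(frames)] = frames  # real steps first, zeros afterwards
    return padded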

Here is my code.

import numpy as np
from keras.models import Sequential
from keras.layers import (TimeDistributed, Masking, Conv2D, BatchNormalization,
                          Dropout, Flatten, GRU, Dense)

# I tried mask_value=0 as well as this all-zeros array (see below)
mask = np.zeros((16, 34, 4), dtype=np.int8)
input_shape = (22, 16, 34, 4)

model = Sequential()
model.add(TimeDistributed(Masking(mask_value=mask), input_shape=input_shape, name='mask'))
model.add(TimeDistributed(Conv2D(100, (5, 2), data_format='channels_first', activation='relu'), name='conv1'))
model.add(TimeDistributed(BatchNormalization(), name='bn1'))
model.add(Dropout(0.5, name='drop1'))
model.add(TimeDistributed(Conv2D(100, (5, 2), data_format='channels_first', activation='relu'), name='conv2'))
model.add(TimeDistributed(BatchNormalization(), name='bn2'))
model.add(Dropout(0.5, name='drop2'))
model.add(TimeDistributed(Conv2D(100, (5, 2), data_format='channels_first', activation='relu'), name='conv3'))
model.add(TimeDistributed(BatchNormalization(), name='bn3'))
model.add(Dropout(0.5, name='drop3'))
model.add(TimeDistributed(Flatten(), name='flatten'))
model.add(GRU(256, activation='tanh', return_sequences=True, name='gru'))
model.add(Dropout(0.4, name='drop_gru'))
model.add(Dense(35, activation='softmax', name='softmax'))
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['acc'])

Here is the model structure (model.summary()):

_________________________________________________________________  
Layer (type)                 Output Shape              Param #     
=================================================================  
mask (TimeDistributed)       (None, 22, 16, 34, 4)     0           
_________________________________________________________________  
conv1 (TimeDistributed)      (None, 22, 100, 30, 3)    16100       
_________________________________________________________________  
bn1 (TimeDistributed)        (None, 22, 100, 30, 3)    12          
_________________________________________________________________  
drop1 (Dropout)              (None, 22, 100, 30, 3)    0           
_________________________________________________________________  
conv2 (TimeDistributed)      (None, 22, 100, 26, 2)    100100      
_________________________________________________________________  
bn2 (TimeDistributed)        (None, 22, 100, 26, 2)    8           
_________________________________________________________________  
drop2 (Dropout)              (None, 22, 100, 26, 2)    0           
_________________________________________________________________  
conv3 (TimeDistributed)      (None, 22, 100, 22, 1)    100100      
_________________________________________________________________  
bn3 (TimeDistributed)        (None, 22, 100, 22, 1)    4           
_________________________________________________________________  
drop3 (Dropout)              (None, 22, 100, 22, 1)    0           
_________________________________________________________________  
flatten (TimeDistributed)    (None, 22, 2200)          0           
_________________________________________________________________  
gru (GRU)                    (None, 22, 256)           1886976     
_________________________________________________________________  
drop_gru (Dropout)           (None, 22, 256)           0           
_________________________________________________________________  
softmax (Dense)              (None, 22, 35)            8995        
=================================================================  
Total params: 2,112,295  
Trainable params: 2,112,283  
Non-trainable params: 12  
_________________________________________________________________

For mask_value, I tried 0 and the mask array above, but neither works; training still runs on all the data, half of which is padding.
Can anyone help me?

By the way, I use TimeDistributed here to connect to the RNN; I know there is another layer called ConvLSTM2D. Does anyone know the difference? ConvLSTM2D needs far more model parameters and trains much more slowly than the TimeDistributed approach.
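Roughly, the ConvLSTM2D alternative I mean is something like this (an illustrative sketch, not my exact setup):

from keras.models import Sequential
from keras.layers import ConvLSTM2D

# Convolution and recurrence fused in one layer, instead of
# TimeDistributed(Conv2D) feeding a separate GRU.
alt = Sequential()
alt.add(ConvLSTM2D(100, (5, 2), data_format='channels_first',
                   return_sequences=True, input_shape=(22, 16, 34, 4)))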

1 Answer:

Answer 0 (score: 0)

Unfortunately, masking is not yet supported by the Keras Conv layers. Several issues about this have been posted on the Keras GitHub page; here is the one with the most discussion on the topic. It seems some implementation details were left hanging, and the issue was never resolved.

The workaround proposed in that discussion is to give the padding characters in the sequence an explicit embedding and do global pooling. Here is another workaround I found (it did not help my use case, but it may help you): keep the mask array around and merge it back in by multiplication, as in the sketch below.
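A minimal sketch of that multiplication workaround, assuming you can supply a per-timestep mask (1.0 for real steps, 0.0 for padding) as a second model input; the layer names and the reduced conv stack here are illustrative, not from the linked post:

from keras.models import Model
from keras.layers import (Input, TimeDistributed, Conv2D, Flatten,
                          GRU, Dense, Lambda)

frames = Input(shape=(22, 16, 34, 4), name='frames')
step_mask = Input(shape=(22, 1), name='step_mask')  # 1.0 = real step, 0.0 = padding

x = TimeDistributed(Conv2D(100, (5, 2), data_format='channels_first',
                           activation='relu'))(frames)
x = TimeDistributed(Flatten())(x)                  # (None, 22, features)
x = Lambda(lambda t: t[0] * t[1])([x, step_mask])  # zero out padded timesteps
x = GRU(256, activation='tanh', return_sequences=True)(x)
out = Dense(35, activation='softmax')(x)

model = Model(inputs=[frames, step_mask], outputs=out)

Note that this only zeroes the features at padded timesteps before the GRU; the GRU still processes those steps, so it is an approximation rather than true masking.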

You could also check out the conversation around this question, which is similar to yours.