如何将Masking应用于复杂的CNN + LSTM网络架构?

时间:2017-04-25 00:48:04

标签: keras

我有一个处理视频(图像序列)的顺序模型。 我的模型看起来像这样:

Time = 0    Time = 1    ....    Time = t
  im@0         im@1     ....       im@t
    |            |                   |
   CNN          CNN     ....        CNN
    |            |                   |
   RNN--------->RNN---->....------->RNN
                                     |
                               some prediction     

它适用于不变长度序列。但我希望它接受可变长度的序列。我们知道,在keras中,Masking层可以帮助我们处理这种情况,但并不是所有层都支持。我有一个相当复杂的CNN架构,所以我似乎无法在CNN之前使用屏蔽层。

有什么方法可以解决这个问题吗?

[修改] 我发现TimeDistributed Wrapper可以使应用掩码成为可能。但我不知道我的实施是否正确:

cnn = make_basenet(...)            # make backbone network
cnn = TimeDistributed(cnn)         # wrap the cnn with TimeDistributed

seq = Sequential()                 # CNN + LSTM
seq.add(Masking(input_shape=(...)) # Masking for TimeDistributed CNN
seq.add(cnn)                       # add CNN
seq.add(Masking())                 # **This is necessary**
seq.add(LSTM(...))                 # add LSTM

如果我在LSTM之前省略了遮罩层,则会引发错误。但是如果我添加这个Masking层,它就可以了。我想知道我是否以正确的方式实施了我的模型?

0 个答案:

没有答案