在Keras合并一个向前的lstm和一个向后的lstm

时间:2017-07-01 06:46:28

标签: keras lstm

我想在Keras中合并前向LSTM和后向LSTM。后向LSTM的输入数组与前向LSTM的输入数组不同。因此,我不能使用keras.layers.Bidirectional。

正向输入是(10,4)。 反向输入为(12,4),在放入模型之前反转。我希望在LSTM之后再次将其反转并将其与前锋合并。

简化模型如下。

from lambdawithmask import Lambda as MaskLambda

def reverse_func(x, mask=None):
    return tf.reverse(x, [False, True, False])

forward = Sequential()
backward = Sequential()
model = Sequential()

forward.add(LSTM(input_shape = (10, 4), output_dim = 4, return_sequences = True))
backward.add(LSTM(input_shape = (12, 4), output_dim = 4, return_sequences = True))
backward.add(MaskLambda(function=reverse_func, mask_function=reverse_func))
model.add(Merge([forward, backward], mode = "concat", concat_axis = 1))

运行时,错误信息为: 传递给'ConcatV2'Op的'值'的列表中的张量类型[bool,float32]并不是全部匹配。

有人能帮帮我吗?我使用Keras(2.0.5)在Python 3.5.2中编码,后端是tensorflow(1.2.1)。

2 个答案:

答案 0 :(得分:1)

首先,如果您有两个不同的输入,则无法使用Sequential模型。您必须使用功能API模型:

from keras.models import Model   

两个第一个模型可以是顺序的,没问题,但结必须是常规模型。当它关于连接时,我也使用函数方法(创建图层,然后传递输入):

junction = Concatenate(axis=1)([forward.output,backward.output])

为什么axis = 1?您只能将具有相同形状的事物连接起来。由于你有10和12,它们是不兼容的,除非你使用这个精确的轴进行合并,这是第二个轴,考虑到你有(BatchSize,TimeSteps,Units)

要创建最终模型,请使用Model,指定输入和输出:

model = Model([forward.input,backward.input], junction)

在要反转的模型中,只使用Lambda图层。 MaskLambda不仅仅是你想要的功能。我还建议你使用张量函数的keras后端:

import keras.backend as K

#instead of the MaskLambda:
backward.add(Lambda(lambda x: K.reverse(x,axes=[1]), output_shape=(12,?))

此处,?是LSTM图层的单位数量。最后见PS。

PS:我不确定output_dim在LSTM层中是否有用。在Lambda图层中它是必要的,但我从未在其他任何地方使用它。形状是"单位数量的自然结果。你放入你的图层。奇怪的是,你没有指定单位数量。

PS2:你究竟想要连接两个不同大小的序列?

答案 1 :(得分:1)

如上面的答案中所述,使用Functional API为多输入/输出模型提供了很大的灵活性。您只需将go_backwards参数设置为True即可反转LSTM图层对输入向量的遍历。

我已经定义了下面的smart_merge函数,该函数将前向和后向LSTM层合并在一起并处理单个遍历情况。

from keras.models import Model
from keras.layers import Input, merge

def smart_merge(vectors, **kwargs):
        return vectors[0] if len(vectors)==1 else merge(vectors, **kwargs)      

input1 = Input(shape=(10,4), dtype='int32')
input2 = Input(shape=(12,4), dtype='int32')

LtoR_LSTM = LSTM(56, return_sequences=False)
LtoR_LSTM_vector = LtoR_LSTM(input1)
RtoL_LSTM = LSTM(56, return_sequences=False, go_backwards=True)
RtoL_LSTM_vector = RtoL_LSTM(input2)

BidireLSTM_vector = [LtoR_LSTM_vector]
BidireLSTM_vector.append(RtoL_LSTM_vector)
BidireLSTM_vector= smart_merge(BidireLSTM_vector, mode='concat')