keras中LSTM隐藏状态的随机初始化

时间:2020-03-05 15:43:13

标签: tensorflow keras deep-learning lstm

我在音乐生成项目中使用了一个模型。模型创建如下

        self.model.add(LSTM(self.hidden_size, input_shape=(self.input_length,self.notes_classes),return_sequences=True,recurrent_dropout=dropout) ,)
        self.model.add(LSTM(self.hidden_size,recurrent_dropout=dropout,return_sequences=True))
        self.model.add(LSTM(self.hidden_size,return_sequences=True))
        self.model.add(BatchNorm())
        self.model.add(Dropout(dropout))
        self.model.add(Dense(256))
        self.model.add(Activation('relu'))
        self.model.add(BatchNorm())
        self.model.add(Dropout(dropout))
        self.model.add(Dense(256))
        self.model.add(Activation('relu'))
        self.model.add(BatchNorm())
        self.model.add(Dense(self.notes_classes))
        self.model.add(Activation('softmax'))

以70%的精度训练该模型之后,无论何时生成音乐,无论输入的音符如何,它总是给出相同的起始音符而几乎没有变化。我认为可以通过在生成之初初始化LSTM的隐藏状态来解决这种情况。我该怎么办?

1 个答案:

答案 0 :(得分:1)

有两种状态,state_h是最后一步的输出; state_c是随身携带状态或记忆。

您应该使用功能性API模型具有多个输入:

main_input = Input((self.input_length,self.notes_classes))
state_h_input = Input((self.hidden_size,))
state_c_input = Input((self.hidden_size, self.hidden_size))

out = LSTM(self.hidden_size, return_sequences=True,recurrent_dropout=dropout,
           initial_state=[state_h_input, state_c_input])(main_input)

#I'm not changing the following layers, they should have their own states if you want to

out = LSTM(self.hidden_size,recurrent_dropout=dropout,return_sequences=True)(out)
out = LSTM(self.hidden_size,return_sequences=True)(out)
out = BatchNorm()(out)
out = Dropout(dropout)(out)
out = Dense(256)(out)
out = Activation('relu')(out)
out = BatchNorm()(out)
out = Dropout(dropout)(out)
out = Dense(256)(out)
out = Activation('relu')(out)
out = BatchNorm()(out)
out = Dense(self.notes_classes)(out)
out = Activation('softmax')(out)

self.model = Model([main_input, state_h_input, state_c_input], out)

按照这种方法,如果您想要可训练的初始状态,甚至可以将其他层的输出用作初始状态。

最大的变化是您将需要通过状态进行训练和预测:

model.fit([original_inputs, state_h_data, state_c_data], y_train) 

在训练期间您可以在其中将零用于状态。