在训练和预测期间保存Keras RNN层状态的最佳方法是什么?

时间:2020-09-25 20:25:54

标签: python tensorflow keras lstm lstm-stateful

我有一个LSTM处理来自多个用户的无限事件流。因此,事件在到达时每次处理一个时间步。每个用户的行为可能有所不同,因此我需要记住该用户先前事件中LSTM的状态。这是我打算如何做的概述

    def LSTM_MODEL(kernel_sizes,
             input_shape,
             initial_state=None,):


inputs = Input(input_shape[1],input_shape[2])

L1 = LSTM(kernal_sizes[0], activation ='relu',return_sequences=True,kernel_regularizer=regularizers.l2(0.00),initial_state = initial_state[0])(inputs)
L2= LSTM(kernal_sizes[1], activation='relu',return_sequences=False, initial_state = initial_state[1])(L1)
L3 = RepeatVector(input_shape[1])(L2)
L4 = LSTM(kernal_sizes[1], activation = 'relu',return_sequences=True(L3)
L5 = LSTM(kernal_sizes[0], activation = 'relu',return_sequences=True)(L4)
output = Dense(input_shape[2])(L5)
          
existing_state = [L1.states,L2.states]

model = tf.keras.Model(inputs=input_chars, outputs=[output,existing_state])

当我致电model.fit时,如何在训练期间访问每个批次的状态?新事件到来时,我需要保留一些{user_id:previous_state}的字典,但是我不确定如何实现。

我曾经考虑过在训练过程中简单地使用有状态LSTM,但是在完成每个user_id的完整序列后,我仍然不得不在训练过程中重置状态,因为我的能力超出了单个批次。

指导将不胜感激!

0 个答案:

没有答案