我正在尝试使用自定义批处理生成器来创建序列到序列模型:
我的模型有一个LSTM层,该层返回序列和最终状态:
lstm_layer = tf.keras.layers.CuDNNLSTM(state_size, return_sequences=True, return_state=True, stateful=True)
lstm,state_h,state_c = lstm_layer(concat_inputs)
在培训期间,我的batch_generator
将以前的状态作为输入并进行更新。然后,在训练批次之前,我使用最新更新的状态在lstm
层中重置状态:
(batch,states) = batch_generator.next_batch(last_states,seq_length)
lstm_layer.reset_states(states)
然后我使用train_on_batch
训练批处理:
loss = model.train_on_batch(batch[0],y=batch[1])
我想不出一种方法来在每批结束时从模型中提取lstm
状态(state_h,state_c
)。
我当前正在使用一种解决方法,该方法具有一些代码异味:
model._make_train_function()
model.train_function.outputs += [state_h,state_c]
x,y,sample_weights = model._standardize_user_data(inputs, targets, None, None)
outputs = model.train_function(x + y + sample_weights)
loss = outputs[0] #loss = model.train_on_batch(inputs,y=targets)
last_states = (outputs[2],outputs[3])
有没有一种更好的方法可以实现,而无需编写自己的train_on_batch
?