从每个批次中提取状态-Keras中的train_on_batch(Tensorflow)

时间:2018-11-02 00:29:44

标签: python tensorflow keras rnn

我要实现的目标

我正在尝试使用自定义批处理生成器来创建序列到序列模型:

  • 输入是长序列,无法放入单个批次
  • 序列的长度差异很大
  • 每次生成批处理时,批处理生成器都会将序列移动一定量
  • 如果任何单个序列到达末尾,则批处理生成器将用新的序列替换该批处理的特定部分
  • 我希望前一批的状态在下一批中保持不变;而所有新添加的序列都会被零状态取代

我要做什么

我的模型有一个LSTM层,该层返回序列和最终状态:

lstm_layer = tf.keras.layers.CuDNNLSTM(state_size, return_sequences=True, return_state=True, stateful=True)
lstm,state_h,state_c = lstm_layer(concat_inputs)

在培训期间,我的batch_generator将以前的状态作为输入并进行更新。然后,在训练批次之前,我使用最新更新的状态在lstm层中重置状态:

(batch,states) = batch_generator.next_batch(last_states,seq_length)
lstm_layer.reset_states(states)

然后我使用train_on_batch训练批处理:

loss = model.train_on_batch(batch[0],y=batch[1])

问题

我想不出一种方法来在每批结束时从模型中提取lstm状态(state_h,state_c)。

我的解决方法

我当前正在使用一种解决方法,该方法具有一些代码异味:

model._make_train_function()
model.train_function.outputs += [state_h,state_c]
x,y,sample_weights = model._standardize_user_data(inputs, targets, None, None)
outputs = model.train_function(x + y + sample_weights)
loss = outputs[0] #loss = model.train_on_batch(inputs,y=targets)
last_states = (outputs[2],outputs[3])

有没有一种更好的方法可以实现,而无需编写自己的train_on_batch

0 个答案:

没有答案