Keras model.reset_states()与tf.train.MonitoredTrainingSession不兼容

时间:2019-06-09 12:41:00

标签: tensorflow session keras reset

我想使用tf.train.MonitoredTrainingSession()来训练Keras中描述的模型。该模型是有状态的模型,因此我想在每个时期后重置状态。

一个问题是,如果我致电model.reset_states(),它将产生以下错误。

RuntimeError:图形已完成,无法修改。

如果使用tf.Session()代替tf.train.MonitoredTrainingSession(),则不会出现此错误。

例如,在下面的示例代码中,即使它不是培训代码,也会生成相同的错误消息。

#!/usr/bin/python

import tensorflow as tf


inputs1 = tf.reshape(tf.linspace(0.0, 100.0, 10), (1, 2, 5)) 
inputs2 = tf.reshape(tf.linspace(100.0, 0.0, 10), (1, 2, 5)) 

model = tf.keras.Sequential([
    tf.keras.layers.LSTM(
    5,  
    return_sequences=True, stateful=True)
])

outputs1 = model(inputs1)
outputs2 = model(inputs2)

with tf.train.MonitoredTrainingSession() as sess:
  model.reset_states()
  print (sess.run(outputs1))
  model.reset_states()
  print (sess.run(outputs2))

我找到了两种方法来解决此问题:

  1. 在重置统计信息之前要使用tf.get_current_graph()._unsafe_unfinalize()

  2. 使用tf.Session()而不是tf.train.MonitoedTrainingSession()

但是我认为这两种方法都不理想。 您能否提出这种情况下最好的解决方案?

0 个答案:

没有答案