我想使用tf.train.MonitoredTrainingSession()
来训练Keras中描述的模型。该模型是有状态的模型,因此我想在每个时期后重置状态。
一个问题是,如果我致电model.reset_states()
,它将产生以下错误。
RuntimeError:图形已完成,无法修改。
如果使用tf.Session()
代替tf.train.MonitoredTrainingSession()
,则不会出现此错误。
例如,在下面的示例代码中,即使它不是培训代码,也会生成相同的错误消息。
#!/usr/bin/python
import tensorflow as tf
inputs1 = tf.reshape(tf.linspace(0.0, 100.0, 10), (1, 2, 5))
inputs2 = tf.reshape(tf.linspace(100.0, 0.0, 10), (1, 2, 5))
model = tf.keras.Sequential([
tf.keras.layers.LSTM(
5,
return_sequences=True, stateful=True)
])
outputs1 = model(inputs1)
outputs2 = model(inputs2)
with tf.train.MonitoredTrainingSession() as sess:
model.reset_states()
print (sess.run(outputs1))
model.reset_states()
print (sess.run(outputs2))
我找到了两种方法来解决此问题:
在重置统计信息之前要使用tf.get_current_graph()._unsafe_unfinalize()
。
使用tf.Session()
而不是tf.train.MonitoedTrainingSession()
。
但是我认为这两种方法都不理想。 您能否提出这种情况下最好的解决方案?