我正在运行一个用于语言建模的基本lstm代码。
但我不想做BPTT
。我想做tf.stop_gradient(state)
with tf.variable_scope("RNN"):
for time_step in range(N):
if time_step > 0: tf.get_variable_scope().reuse_variables()
(cell_output, state) = cell(inputs[:, time_step, :], state)
但是,state
是LSTMStateTuple
,所以我尝试了:
for lli in range(len(state)):
print(state[lli].c, state[lli].h)
state[lli].c = tf.stop_gradient(state[lli].c)
state[lli].h = tf.stop_gradient(state[lli].h)
但我收到AttributeError: can't set attribute
错误:
File "/home/liyu-iri/IRRNNL/word-rnn/ptb/models/decoupling.py", line 182, in __init__
state[lli].c = tf.stop_gradient(state[lli].c)
AttributeError: can't set attribute
我还尝试使用tf.assign
,但state[lli].c
不是变量。
所以,我想知道如何阻止LSTMStateTuple
的渐变?
或者,我怎么能阻止BPTT?我只想在单帧中做BP。
非常感谢!
答案 0 :(得分:0)
我认为这是一个纯粹的python问题:LSTMStateTuple只是一个collections.namedtuple而python不允许你在那里分配元素(就像在其他元组中一样)。解决方案是创建一个全新的,例如就像在stopped_state = LSTMStateTuple(tf.stop_gradient(old_tuple.c), tf.stop_gradient(old_tuple.h))
中一样,然后使用它(或那些列表)作为状态。如果你坚持要替换现有的元组,我认为namedtuple有一个_replace方法,请参阅here,如old_tuple._replace(c=tf.stop_gradient(...))
。希望有所帮助!