Extension of: What is the use of a "reuse" parameter of tf.contrib.layers functions?
Question: Although this issue has been raised on GitHub and may be resolved in a later version of TensorFlow, I have not found an existing solution so far. Is there a workaround that could serve in the meantime?
Code:
import inspect
import tensorflow as tf
from tensorflow.contrib import rnn

state_size = 4

def lstm_cell():
    # Pass 'reuse' only when this TensorFlow version supports it (as in the PTB tutorial).
    if 'reuse' in inspect.getargspec(tf.contrib.rnn.BasicLSTMCell.__init__).args:
        return tf.contrib.rnn.BasicLSTMCell(state_size, forget_bias=0.0, state_is_tuple=True,
                                            reuse=tf.get_variable_scope().reuse)
    else:
        return tf.contrib.rnn.BasicLSTMCell(state_size, forget_bias=0.0, state_is_tuple=True)

cell = lstm_cell()
cell = rnn.DropoutWrapper(cell, output_keep_prob=0.5)
cell = rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)  # the same cell object repeated num_layers times
states_series, current_state = tf.nn.dynamic_rnn(cell, tf.expand_dims(batchX_placeholder, -1),
                                                 initial_state=rnn_tuple_state)
states_series = tf.reshape(states_series, [-1, state_size])
The lstm_cell() function follows the suggestion from https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py, which explains that recent versions of TensorFlow add a 'reuse' parameter to BasicLSTMCell().
In this code, if I set reuse to False, the tf.nn.dynamic_rnn line produces an error:
If I set reuse to True, the error is:
Finally, adding 'scope=None' to dynamic_rnn makes no difference either.
Answer 0: (score: 0)
Have you considered working from the 'reuse = True' error?
If you were using: MultiRNNCell([BasicLSTMCell(...)] * num_layers), change it to: MultiRNNCell([BasicLSTMCell(...) for _ in range(num_layers)]). The list multiplication repeats the same cell object for every layer, while the comprehension builds a separate cell per layer, so each layer gets its own variables.
The following code snippet works for me (already answered here):
import tensorflow as tf

def lstm_cell():
    cell = tf.contrib.rnn.NASCell(state_size, reuse=tf.get_variable_scope().reuse)
    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=0.8)

# Build a fresh cell object for each layer instead of repeating one object.
rnn_cells = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(num_layers)], state_is_tuple=True)
outputs, current_state = tf.nn.dynamic_rnn(rnn_cells, x, initial_state=rnn_tuple_state)
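
For reference, a minimal sketch of the same per-layer construction applied to the BasicLSTMCell/DropoutWrapper setup from the question (assuming num_layers, batchX_placeholder and rnn_tuple_state are defined as in the question; the 0.5 keep probability is carried over from there):

import tensorflow as tf
from tensorflow.contrib import rnn

def lstm_cell():
    # One BasicLSTMCell per layer; 'reuse' follows the enclosing variable scope.
    cell = tf.contrib.rnn.BasicLSTMCell(state_size, forget_bias=0.0, state_is_tuple=True,
                                        reuse=tf.get_variable_scope().reuse)
    return rnn.DropoutWrapper(cell, output_keep_prob=0.5)

# A separate cell (and DropoutWrapper) object is created for every layer.
cell = rnn.MultiRNNCell([lstm_cell() for _ in range(num_layers)], state_is_tuple=True)
states_series, current_state = tf.nn.dynamic_rnn(cell,
                                                 tf.expand_dims(batchX_placeholder, -1),
                                                 initial_state=rnn_tuple_state)
states_series = tf.reshape(states_series, [-1, state_size])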