我正在尝试使用 'tf.contrib.rnn.LayerNormBasicLSTMCell' 来缩短训练 LSTM 单元所需的时间。
在使用 LayerNormBasicLSTMCell 之前,我使用的是 'tf.contrib.rnn.BasicLSTMCell',以下代码运行得非常正常。
class Model():
    def __init__(self, args, infer=False):
        """Build a manually-unrolled per-sample LSTM graph.

        Args:
            args: configuration object; fields read here: rnn_size,
                batch_size, seq_length, keep_prob.
            infer: when False and keep_prob < 1, output dropout is applied
                (training mode); when True, dropout is skipped (inference).
        """
        ''' Some definitions are here '''
        # Cell constructor kept in one variable so the cell type can be
        # swapped in a single place.
        # NOTE(review): swapping in tf.contrib.rnn.LayerNormBasicLSTMCell
        # fails in the cell call below, because that cell unpacks its state
        # as `c, h = state` (an LSTMStateTuple), while this code feeds a
        # single concatenated state tensor — presumably the cause of the
        # "Tensor objects are not iterable" error; confirm against the
        # cell's state_size/state_is_tuple semantics.
        cell_fn = tf.contrib.rnn.BasicLSTMCell
        def get_cell():
            # Factory for a fresh cell of the configured hidden size.
            return cell_fn(args.rnn_size)
        cell = get_cell()
        if (infer == False and args.keep_prob < 1): # training mode
            # Dropout is applied to the cell outputs only while training.
            cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=args.keep_prob)
        self.cell = cell
        # Cell states : batch_size x (1, cell.state_size)
        # Split one zero tensor into a Python list of per-sample
        # (1, state_size) state slices so each batch element can be
        # stepped independently below.
        # NOTE(review): assumes cell.state_size is a plain int (flat state
        # tensor); tuple-state cells break this construction — verify.
        zero_state = tf.split(tf.zeros([args.batch_size, cell.state_size]), axis=0, num_or_size_splits=args.batch_size)
        self.state_in = tf.identity(zero_state, name='state_in')
        self.state_out = tf.split(tf.zeros([args.batch_size, cell.state_size]), axis=0, num_or_size_splits=args.batch_size)
        self.output_states = tf.split(tf.zeros([args.batch_size, args.rnn_size]), axis=0, num_or_size_splits=args.batch_size)
        ''' Some definitions are here '''
        # Manual unroll: one cell invocation per (batch element, frame).
        # NOTE(review): embedding_seqs is defined in the elided section
        # above — assumed shape (batch_size, seq_length, rnn_size).
        for b in range(args.batch_size):
            # current embedding sequence : (seq_length x rnn_size)
            current_emd_seq = embedding_seqs[b]
            for f in range(args.seq_length):
                # current embedding frame : (1 x rnn_size)
                current_emd_frame = tf.reshape(current_emd_seq[f], shape=(1, args.rnn_size))
                with tf.variable_scope("rnnlm") as scope:
                    # Share the LSTM weights across every step after the
                    # very first call, which creates the variables.
                    if (b > 0 or f > 0):
                        scope.reuse_variables()
                    # go through LSTM cell
                    self.output_states[b], zero_state[b] = cell(current_emd_frame, zero_state[b])
但是,当我仅仅把这一行代码从
cell_fn = tf.contrib.rnn.BasicLSTMCell
改为
cell_fn = tf.contrib.rnn.LayerNormBasicLSTMCell
时,就发生了以下错误:
Traceback (most recent call last):
File "/home/dooseop/Projects/__IRLLSTM__/version20/kitti_train.py",
line 192, in <module>
main()
File "/home/dooseop/Projects/__IRLLSTM__/version20/kitti_train.py", line 95, in main
train(args)
File "/home/dooseop/Projects/__IRLLSTM__/version20/kitti_train.py", line 112, in train
model = Model(args)
File "/home/dooseop/Projects/__IRLLSTM__/version20/kitti_model.py", line 193, in __init__
self.output_states[b], zero_state[b] = cell(current_emd_frame, zero_state[b])
File "/home/dooseop/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1099, in __call__
output, new_state = self._cell(inputs, state, scope=scope)
File "/home/dooseop/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 232, in __call__
return super(RNNCell, self).__call__(inputs, state)
File "/home/dooseop/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 714, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/home/dooseop/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/rnn/python/ops/rnn_cell.py", line 1431, in
call
c, h = state
File "/home/dooseop/anaconda3/lib/python3.6/site-
packages/tensorflow/python/framework/ops.py", line 401, in __iter__
"Tensor objects are not iterable when eager execution is not "
TypeError: Tensor objects are not iterable when eager execution is not
enabled. To iterate over this tensor use tf.map_fn.
我也尝试过改变 zero_state 和 self.state_out 的大小,结果是一样的。
例如,更改
zero_state = tf.split(tf.zeros([args.batch_size, cell.state_size]), axis=0, num_or_size_splits=args.batch_size)
到
zero_state = tf.split(tf.zeros([args.batch_size, args.rnn_size, 2]), axis=0, num_or_size_splits=args.batch_size)
或
zero_state = tf.split(tf.zeros([args.batch_size, 2* args.rnn_size]), axis=0, num_or_size_splits=args.batch_size)
请告诉我我的用法到底有什么问题……谢谢!