我正在尝试在 TensorFlow 中实现带有双向 RNN 编码、波束搜索(beam search)推理和注意力机制的编码器-解码器(encoder-decoder)模型。前两部分已经可以正常工作,但我在 AttentionWrapper 上遇到了麻烦。我目前遇到以下错误:
InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [8,50] vs. shape[1] = [1028,50]
[[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_12, decoder/while/Identity_8, decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat/axis)]]
这是我的编码层:
def encoding_layer(self):
    """Encode the embedded inputs with a stack of bidirectional LSTM layers.

    Each layer runs a forward and a backward LSTM over the previous layer's
    output and concatenates the two directions along the feature axis, so
    every layer emits 2 * hidden_dim features.

    Side effects:
        self.encoder_outputs: concatenated fw/bw outputs of the last layer.
        self.encoder_final_state: the last layer's final state — a
            (forward, backward) pair of LSTMStateTuples.
            NOTE(review): this (fw, bw) structure must match whatever the
            decoder's `clone(cell_state=...)` expects — confirm against the
            decoder's MultiRNNCell state layout.
    """
    layer_input = self.embedded_encoder_inputs
    for layer in range(self.num_layers):
        # Distinct deterministic seeds per direction and per layer.
        fw_cell = tf.nn.rnn_cell.LSTMCell(
            self.hidden_dim,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2 * layer))
        bw_cell = tf.nn.rnn_cell.LSTMCell(
            self.hidden_dim,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2 * layer + 1))
        (fw_out, bw_out), last_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=fw_cell,
            cell_bw=bw_cell,
            inputs=layer_input,
            sequence_length=self.encoder_lengths,
            scope='BLSTM_' + str(layer),
            dtype=tf.float32)
        # Feed the concatenated directions into the next layer.
        layer_input = tf.concat([fw_out, bw_out], axis=2)
    self.encoder_outputs = layer_input
    self.encoder_final_state = last_state
这是我的解码训练代码:
def decoding_training(self):
    """Build the training-time decoder graph.

    Luong attention over the encoder outputs, a multi-layer LSTM wrapped in
    an AttentionWrapper, and a TrainingHelper fed with the embedded decoder
    inputs (assumed batch-major, i.e. [batch, time, dim] — confirm upstream).

    Side effects: sets self.decoder_cell, self.output_layer and
    self.training_logits.
    """
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=self.hidden_dim,
        memory=self.encoder_outputs,
        memory_sequence_length=self.encoder_lengths,
        dtype=tf.float32)
    self.decoder_cell = tf.contrib.rnn.MultiRNNCell([
        tf.nn.rnn_cell.LSTMCell(
            self.hidden_dim,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
        for _ in range(self.num_layers)])
    self.decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
        cell=self.decoder_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=self.hidden_dim)
    # NOTE(review): encoder_final_state comes from a bidirectional RNN and is
    # a (fw, bw) state pair; confirm it matches the MultiRNNCell's expected
    # num_layers-tuple of LSTMStateTuples before cloning it in here.
    decoder_initial_state = self.decoder_cell.zero_state(
        self.batch_size, tf.float32).clone(cell_state=self.encoder_final_state)
    self.output_layer = Dense(
        units=self.vocab_size,
        kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
    # BUG FIX: the inputs are batch-major, but time_major=True made the
    # helper slice axis 0 (the batch axis) as if it were time. Each decode
    # step then fed the cell a [max_time, dim] slice while the attention
    # vector was [batch, dim], producing the reported ConcatOp mismatch
    # ("shape[0] = [8,50] vs. shape[1] = [1028,50]") inside
    # attention_wrapper's concat of inputs and attention.
    training_helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=self.embedded_decoder_inputs,
        sequence_length=self.decoder_lengths,
        time_major=False)
    training_decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=training_helper,
        initial_state=decoder_initial_state,
        output_layer=self.output_layer)
    training_outputs = tf.contrib.seq2seq.dynamic_decode(
        decoder=training_decoder,
        impute_finished=True,
        maximum_iterations=self.max_output_length)[0]
    self.training_logits = training_outputs.rnn_output
最后,这是解码推理(inference)代码:
# One <START> token id per batch entry (BeamSearchDecoder tiles these
# across beams internally, so batch_size — not batch_size * beam_width —
# is correct here).
start_tokens = tf.tile(
    input=tf.constant([self.vocab.index("<START>")], dtype=tf.int32),
    multiples=[self.batch_size])
# BUG FIX: BeamSearchDecoder runs with an effective batch of
# batch_size * beam_width, so the encoder final state cloned into the
# initial state must be tiled beam_width times with tile_batch; cloning
# the untiled state leaves the cell state at [batch_size, ...] while the
# rest of the decoder state is [batch_size * beam_width, ...].
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
    self.encoder_final_state, multiplier=self.beam_width)
# NOTE(review): self.decoder_cell's attention mechanism was built over the
# *untiled* encoder_outputs / encoder_lengths. For beam search those must
# be tiled too, which means building a separate AttentionWrapper for
# inference (memory=tile_batch(encoder_outputs, beam_width),
# memory_sequence_length=tile_batch(encoder_lengths, beam_width)) that
# shares variables with the training wrapper — confirm per the TF NMT
# tutorial's beam-search section.
decoder_initial_state = self.decoder_cell.zero_state(
    self.batch_size * self.beam_width, tf.float32).clone(
        cell_state=tiled_encoder_final_state)
inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=self.decoder_cell,
    embedding=self.embedding_space,
    start_tokens=start_tokens,
    end_token=self.vocab.index("<END>"),
    initial_state=decoder_initial_state,
    beam_width=self.beam_width,
    output_layer=self.output_layer,
    length_penalty_weight=0.0)
# impute_finished must stay False with beam search (unsupported there).
inference_decoder_output = tf.contrib.seq2seq.dynamic_decode(
    inference_decoder,
    impute_finished=False,
    maximum_iterations=self.max_output_length)[0]
# Keep only the top beam: predicted_ids is [batch, time, beam_width].
self.inference_logits = inference_decoder_output.predicted_ids[:, :, 0]
此代码部分基于Tensorflow NMT documentation。这是完整的堆栈跟踪:
2018-07-24 14:51:53.891498: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
WARNING:tensorflow:From /usr/local/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py:417: calling reverse_sequence (from tensorflow.python.ops.array_ops) with seq_dim is deprecated and will be removed in a future version.
Instructions for updating:
seq_dim is deprecated, use seq_axis instead
WARNING:tensorflow:From /usr/local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:432: calling reverse_sequence (from tensorflow.python.ops.array_ops) with batch_dim is deprecated and will be removed in a future version.
Instructions for updating:
batch_dim is deprecated, use batch_axis instead
Traceback (most recent call last):
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1322, in _do_call
return fn(*args)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [8,50] vs. shape[1] = [1028,50]
[[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_12, decoder/while/Identity_8, decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat/axis)]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "advanced_encoder_decoder.py", line 194, in <module>
simple_example()
File "advanced_encoder_decoder.py", line 183, in simple_example
seq2seq.fit(X, y)
File "../tf_model_base.py", line 128, in fit
feed_dict=self.train_dict(X_batch, y_batch))
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
run_metadata)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [8,50] vs. shape[1] = [1028,50]
[[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_12, decoder/while/Identity_8, decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat/axis)]]
Caused by op 'decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat', defined at:
File "advanced_encoder_decoder.py", line 194, in <module>
simple_example()
File "advanced_encoder_decoder.py", line 183, in simple_example
seq2seq.fit(X, y)
File "../tf_model_base.py", line 113, in fit
self.build_graph()
File "../tf_encoder_decoder.py", line 42, in build_graph
self.decoding_layer()
File "../tf_encoder_decoder.py", line 119, in decoding_layer
self.decoding_training()
File "advanced_encoder_decoder.py", line 125, in decoding_training
maximum_iterations=self.max_output_length)[0]
File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 323, in dynamic_decode
swap_memory=swap_memory)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3209, in while_loop
result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2941, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2878, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3179, in <lambda>
body = lambda i, lv: (i + 1, orig_body(*lv))
File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 266, in body
decoder_finished) = decoder.step(time, inputs, state)
File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 137, in step
cell_outputs, cell_state = self._cell(inputs, state)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 232, in __call__
return super(RNNCell, self).__call__(inputs, state)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 329, in __call__
outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 703, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py", line 1410, in call
cell_inputs = self._cell_input_fn(inputs, state.attention)
File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py", line 1187, in <lambda>
lambda inputs, attention: array_ops.concat([inputs, attention], -1))
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1113, in concat
return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1029, in concat_v2
"ConcatV2", values=values, axis=axis, name=name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3414, in create_op
op_def=op_def)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1740, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [8,50] vs. shape[1] = [1028,50]
[[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_12, decoder/while/Identity_8, decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat/axis)]]