Tensorflow:seq2seq 注意力机制尺寸不匹配

时间:2018-07-24 21:56:17

标签: tensorflow

尝试在Tensorflow中实现具有双向RNN编码、波束搜索推理和注意力机制的编解码器模型。前两个部分可以正常工作,但我在AttentionWrapper上遇到了麻烦。我当前遇到以下错误:

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [8,50] vs. shape[1] = [1028,50]
 [[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_12, decoder/while/Identity_8, decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat/axis)]]

这是我的编码层:

def encoding_layer(self):
    """Build a stacked bidirectional LSTM encoder.

    Runs the embedded encoder inputs through ``self.num_layers``
    bidirectional LSTM layers; each layer consumes the forward/backward
    outputs of the previous layer concatenated on the feature axis.

    Sets:
        self.encoder_outputs: top-layer outputs, shape
            [batch, time, 2 * hidden_dim] (fw and bw halves concatenated).
        self.encoder_final_state: the (forward, backward) final state pair
            of the *top* layer only.

    NOTE(review): ``encoder_final_state`` is a (fw, bw) pair of
    LSTMStateTuples, which does not in general match the structure of the
    decoder MultiRNNCell state it is later cloned into — confirm this
    lines up with the decoder (it only happens to match when
    num_layers == 2).
    """
    layer_input = self.embedded_encoder_inputs
    for layer in range(self.num_layers):
        cell_fw = tf.nn.rnn_cell.LSTMCell(
            self.hidden_dim,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2 * layer))
        cell_bw = tf.nn.rnn_cell.LSTMCell(
            self.hidden_dim,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2 * layer + 1))

        outputs, final_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            inputs=layer_input,
            sequence_length=self.encoder_lengths,
            scope='BLSTM_' + str(layer),
            dtype=tf.float32)

        # The next layer sees both directions fused on the feature axis.
        layer_input = tf.concat(outputs, axis=2)

    self.encoder_outputs = layer_input
    self.encoder_final_state = final_state

这是我的解码训练部分:

def decoding_training(self):
    """Build the training-time attention decoder.

    Uses Luong attention over the bidirectional encoder outputs with a
    ``num_layers``-deep LSTM MultiRNNCell wrapped in an AttentionWrapper,
    then decodes the embedded target inputs with a TrainingHelper.

    Sets:
        self.decoder_cell, self.output_layer, self.training_logits.

    Fix: ``TrainingHelper`` was given ``time_major=True`` while
    ``self.embedded_decoder_inputs`` is batch-major ([batch, time, emb]).
    With ``time_major=True`` the helper slices along the wrong axis, so each
    step feeds a [max_time, emb] slice whose leading dimension disagrees
    with the attention vector's batch dimension — producing exactly the
    reported ``ConcatOp ... shape[0] = [8,50] vs. shape[1] = [1028,50]``
    error inside ``attention_wrapper/concat`` (8 = max decoder length,
    1028 = batch size, 50 = feature dim). Batch-major inputs must use the
    default ``time_major=False``.
    """
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=self.hidden_dim,
        memory=self.encoder_outputs,
        memory_sequence_length=self.encoder_lengths,
        dtype=tf.float32)

    self.decoder_cell = tf.contrib.rnn.MultiRNNCell([
        tf.nn.rnn_cell.LSTMCell(
            self.hidden_dim,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
        for _ in range(self.num_layers)])

    self.decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
        cell=self.decoder_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=self.hidden_dim)

    # NOTE(review): cloning the bidirectional (fw, bw) encoder state into a
    # MultiRNNCell state only lines up structurally when num_layers == 2 —
    # verify, or map the fw/bw states into per-layer decoder states explicitly.
    decoder_initial_state = self.decoder_cell.zero_state(
        self.batch_size, tf.float32).clone(cell_state=self.encoder_final_state)

    self.output_layer = Dense(
        units=self.vocab_size,
        kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

    training_helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=self.embedded_decoder_inputs,
        sequence_length=self.decoder_lengths,
        time_major=False)  # inputs are batch-major; True caused the concat error

    training_decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=training_helper,
        initial_state=decoder_initial_state,
        output_layer=self.output_layer)

    training_outputs = tf.contrib.seq2seq.dynamic_decode(
        decoder=training_decoder,
        impute_finished=True,
        maximum_iterations=self.max_output_length)[0]

    self.training_logits = training_outputs.rnn_output

最后,这是解码推理部分:

# Beam-search inference decoding (fragment — the enclosing `def` line is not
# visible in this paste; the original indentation is preserved as-is).
#
# NOTE(review): for BeamSearchDecoder the attention *memory* must also be
# tiled: tf.contrib.seq2seq.tile_batch should be applied to
# encoder_outputs, encoder_lengths AND encoder_final_state
# (multiplier=beam_width), and a *separate* AttentionWrapper must be built
# from that tiled memory for inference. Reusing the training-time
# self.decoder_cell — whose attention memory has batch size batch_size,
# not batch_size * beam_width — while cloning an un-tiled
# encoder_final_state is a classic source of batch-dimension mismatches
# in attention during beam search. Confirm against the TF NMT tutorial.
start_tokens = tf.tile(
        input=tf.constant([self.vocab.index("<START>")], dtype=tf.int32),
        multiples=[self.batch_size])

    # decoder_initial_state = tf.contrib.seq2seq.tile_batch(self.encoder_final_state, multiplier=self.beam_width)
    # zero_state is sized for batch_size * beam_width, but cell_state comes
    # from the un-tiled encoder — see NOTE(review) above.
    decoder_initial_state = self.decoder_cell.zero_state(self.batch_size * self.beam_width, tf.float32).clone(cell_state=self.encoder_final_state)

    inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=self.decoder_cell,
        embedding=self.embedding_space,
        start_tokens=start_tokens,
        end_token=self.vocab.index("<END>"),
        initial_state=decoder_initial_state,
        beam_width=self.beam_width,
        output_layer=self.output_layer,
        length_penalty_weight=0.0)

    # impute_finished must stay False with beam search (per TF docs).
    inference_decoder_output = tf.contrib.seq2seq.dynamic_decode(
        inference_decoder,
        impute_finished=False,
        maximum_iterations=self.max_output_length)[0]

    # Keep only the top beam (index 0) as the inference prediction.
    self.inference_logits = inference_decoder_output.predicted_ids[:,:,0]

此代码部分基于Tensorflow NMT documentation。这是完整的堆栈跟踪:

2018-07-24 14:51:53.891498: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
WARNING:tensorflow:From /usr/local/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py:417: calling reverse_sequence (from tensorflow.python.ops.array_ops) with seq_dim is deprecated and will be removed in a future version.
Instructions for updating:
seq_dim is deprecated, use seq_axis instead
WARNING:tensorflow:From /usr/local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:432: calling reverse_sequence (from tensorflow.python.ops.array_ops) with batch_dim is deprecated and will be removed in a future version.
Instructions for updating:
batch_dim is deprecated, use batch_axis instead
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1322, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [8,50] vs. shape[1] = [1028,50]
     [[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_12, decoder/while/Identity_8, decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat/axis)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "advanced_encoder_decoder.py", line 194, in <module>
    simple_example()
  File "advanced_encoder_decoder.py", line 183, in simple_example
    seq2seq.fit(X, y)
  File "../tf_model_base.py", line 128, in fit
    feed_dict=self.train_dict(X_batch, y_batch))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [8,50] vs. shape[1] = [1028,50]
     [[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_12, decoder/while/Identity_8, decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat/axis)]]

Caused by op 'decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat', defined at:
  File "advanced_encoder_decoder.py", line 194, in <module>
    simple_example()
  File "advanced_encoder_decoder.py", line 183, in simple_example
    seq2seq.fit(X, y)
  File "../tf_model_base.py", line 113, in fit
    self.build_graph()
  File "../tf_encoder_decoder.py", line 42, in build_graph
    self.decoding_layer()
  File "../tf_encoder_decoder.py", line 119, in decoding_layer
    self.decoding_training()
  File "advanced_encoder_decoder.py", line 125, in decoding_training
    maximum_iterations=self.max_output_length)[0]
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 323, in dynamic_decode
    swap_memory=swap_memory)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3209, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2941, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2878, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3179, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 266, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 137, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 232, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 329, in __call__
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 703, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py", line 1410, in call
    cell_inputs = self._cell_input_fn(inputs, state.attention)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py", line 1187, in <lambda>
    lambda inputs, attention: array_ops.concat([inputs, attention], -1))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1113, in concat
    return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1029, in concat_v2
    "ConcatV2", values=values, axis=axis, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3414, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1740, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [8,50] vs. shape[1] = [1028,50]
     [[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_12, decoder/while/Identity_8, decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat/axis)]]

0 个答案:

没有答案