RNN model: inference on sentences longer than the maximum sequence length used during training

Date: 2016-10-05 18:54:26

Tags: tensorflow

I am training an RNN model (using the rnn.dynamic_rnn method), and my data matrix has shape num_examples x max_sequence_length x num_features. During training I don't want to increase max_sequence_length beyond 50 or 100, because doing so increases training time and memory usage. All sentences in my training set are shorter than 50 tokens. At test time, however, I want the model to run inference on sentences of up to 500 tokens. Is that possible, and how would I do it?

1 answer:

Answer 0 (score: 0)

@sonal - yes, this is possible. At test time we are usually interested in passing a single example rather than a whole batch of data. So what you need is to pass the array for a single example, say

test_index = [10, 23, 42, 12, 24, 50]

to dynamic_rnn. The prediction must be based on the final hidden state. Inside dynamic_rnn, I think you can pass sentences longer than max_length at test time. If not, you can write a custom decoder function that computes the GRU or LSTM states using the weights obtained during training. The idea is that you keep generating outputs until you reach the maximum length for your test case, or until the model produces an 'EOS' special token. I would prefer that you use a decoder after obtaining the final hidden state from the encoder; that also tends to give better results.
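The reason longer test sequences are possible at all is that recurrent weights are sized only by the feature and hidden dimensions, never by the sequence length. A minimal NumPy sketch of this (all names and sizes here are illustrative, not taken from the asker's model):

```python
import numpy as np

# Hypothetical sizes; only num_features and num_hidden fix the weight
# shapes -- the sequence length never appears in any parameter shape.
num_features, num_hidden = 8, 16
rng = np.random.default_rng(0)

# GRU parameters (names are illustrative)
W_z = rng.normal(size=(num_features + num_hidden, num_hidden)) * 0.1
W_r = rng.normal(size=(num_features + num_hidden, num_hidden)) * 0.1
W_c = rng.normal(size=(num_features + num_hidden, num_hidden)) * 0.1

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def run_gru(inputs):
    """Unroll a GRU over inputs of shape (T, num_features) for any T."""
    h = np.zeros(num_hidden)
    for x_t in inputs:
        xh = np.concatenate([x_t, h])
        z = sigmoid(xh @ W_z)                            # update gate
        r = sigmoid(xh @ W_r)                            # reset gate
        c = np.tanh(np.concatenate([x_t, r * h]) @ W_c)  # candidate state
        h = (1 - z) * c + z * h
    return h

# The same weights handle a 50-step and a 500-step sequence.
h_short = run_gru(rng.normal(size=(50, num_features)))
h_long = run_gru(rng.normal(size=(500, num_features)))
print(h_short.shape, h_long.shape)  # (16,) (16,)
```

This is why unrolling past the training length works mechanically; whether the model generalizes to much longer sequences is a separate, empirical question.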

    from tensorflow.python.ops import tensor_array_ops

    # condition for the while-loop, used for early stopping
    def decoder_cond(time, state, output_ta_t):
        return tf.less(time, max_sequence_length)

    # decoder_body_builder is just a wrapper to parse feedback
    def decoder_body_builder(feedback=False):
        # the decoder body, this is where the RNN magic happens!
        def decoder_body(time, old_state, output_ta_t):
            # when validating we need the previous prediction; handle via feedback
            if feedback:
                def from_previous():
                    prev_1 = tf.matmul(old_state, W_out) + b_out
                    a_max = tf.argmax(prev_1, 1)
                    # look up the predicted token's embedding; stop the loop
                    # once the predicted index is the EOS token
                    return tf.gather(embeddings, a_max)
                # read the real input at time 0, feed back the model's own
                # prediction afterwards
                x_t = tf.cond(tf.greater(time, 0), from_previous, lambda: input_ta.read(0))
            else:
                # otherwise we just read the next timestep
                x_t = input_ta.read(time)

            # calculate the GRU
            z = tf.sigmoid(tf.matmul(x_t, W_z_x) + tf.matmul(old_state, W_z_h) + b_z)    # update gate
            r = tf.sigmoid(tf.matmul(x_t, W_r_x) + tf.matmul(old_state, W_r_h) + b_r)    # reset gate
            c = tf.tanh(tf.matmul(x_t, W_c_x) + tf.matmul(r * old_state, W_c_h) + b_c)   # candidate state
            new_state = (1 - z) * c + z * old_state  # new state

            # write the output
            output_ta_t = output_ta_t.write(time, new_state)

            # return in "input-to-next-step" style
            return (time + 1, new_state, output_ta_t)
        return decoder_body

    # set up variables to loop with
    output_ta = tensor_array_ops.TensorArray(tf.float32, size=1, dynamic_size=True, infer_shape=False)
    time = tf.constant(0)
    loop_vars = [time, initial_state, output_ta]

    # run the while-loop for decoding
    _, state, output_ta = tf.while_loop(decoder_cond,
                                        decoder_body_builder(feedback=True),
                                        loop_vars,
                                        swap_memory=swap)

This is just a code snippet, so modify it to fit your setup. For more details, see https://github.com/alrojo/tensorflow-tutorial
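The "generate until EOS or a length cap" idea above can also be sketched outside the graph, as a plain Python greedy-decoding loop. Everything here is illustrative (the EOS id, vocabulary size, and the stand-in `gru_step` are assumptions, not the asker's trained model):

```python
import numpy as np

# Illustrative constants: an assumed EOS token id, a hard length cap,
# and toy vocabulary / hidden sizes.
EOS_ID, MAX_TEST_LEN, vocab, hid = 0, 500, 20, 16
rng = np.random.default_rng(1)
W_out = rng.normal(size=(hid, vocab)) * 0.1       # output projection
embeddings = rng.normal(size=(vocab, hid)) * 0.1  # token embeddings

def gru_step(x, h):
    # Stand-in for the trained recurrent cell (weights omitted for brevity).
    return np.tanh(x + h)

def greedy_decode(h):
    """Feed each prediction back in until EOS or the length cap."""
    tokens = []
    x = embeddings[EOS_ID]            # conventional start-of-sequence input
    for _ in range(MAX_TEST_LEN):
        h = gru_step(x, h)
        tok = int(np.argmax(h @ W_out))  # greedy pick
        tokens.append(tok)
        if tok == EOS_ID:             # early stop on the EOS token
            break
        x = embeddings[tok]           # feedback: prediction becomes next input
    return tokens

out = greedy_decode(np.zeros(hid))    # h would be the encoder's final state
```

The same structure is what the `feedback=True` branch of the snippet implements inside `tf.while_loop`; only the EOS early-stop is left to you there.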