Tensorflow LSTM编码器测试代码尺寸不匹配错误

时间:2017-03-15 08:50:37

标签: python tensorflow dimensions lstm

我正在使用tensorflow编写一个lstm编码器。然后我正在编写测试代码以查看我的代码是否有效。这是我的代码:

import tensorflow as tf

class Encoder(object):
    def __init__(self, state_size,vocab_dim, FLAGS):
        self.state_size = state_size
        self.FLAGS = FLAGS
        self.vocab_dim = vocab_dim
    # Return hidden representation HQ, HP of question and paragraph respectively
    def LSTMpreprocessing(self,paragraph,question, paragraph_length,question_length):
        #Encode Question
        with tf.variable_scope("Q_encode"):
            cell = tf.nn.rnn_cell.BasicLSTMCell(self.state_size)
            HQ, _ = tf.nn.dynamic_rnn(cell,question,sequence_length = question_length, dtype = tf.float32)

        #Encode Paragraph
        with tf.variable_scope("P_encode"):
            cell = tf.nn.rnn_cell.BasicLSTMCell(self.state_size)
            HP, _ = tf.nn.dynamic_rnn(cell,paragraph,sequence_length = paragraph_length, dtype = tf.float32)
        return HQ,HP

目前我正在尝试检查LSTMpreprocessing返回的内容。为此我写了以下测试代码:

def main(_):
    paragraph_placeholder = tf.placeholder(tf.int32, (None, 4), name="paragraph_placeholder")
    question_placeholder = tf.placeholder(tf.int32, (None, 3), name="question_placeholder")
    paragraph_length = tf.placeholder(tf.int32, (None), name="paragraph_length")
    question_length = tf.placeholder(tf.int32, (None), name="question_length")
    encoder = Encoder(4,3,None)

    paragraph = [[0,1],[1,0],[0,2],[5,3]]
    question = [[3,3],[5,5],[1,1]]
    func = encoder.LSTMpreprocessing(paragraph_placeholder,question_placeholder,paragraph_length,question_length)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        HQ,HP = sess.run(func, feed_dict = {paragraph_placeholder :paragraph, question_placeholder : question,paragraph_length : 4, question_length : 3}) 
    print(HQ.get_shape().as_list())
    print(HP.get_shape().as_list())

当我运行上面的测试代码时,我收到以下错误:

    ValueError: Dimension must be 2 but is 3 for 
'Q_encode/transpose' (op: 'Transpose') with input shapes: [?,3], [3].

作为张量流中的新手,我完全无法弄清楚我做错了什么。有人可以帮我指出我犯的错误吗?

0 个答案:

没有答案