使用word2vec作为张量流的输入的LSTM的可变句子长度

时间:2018-09-27 07:33:13

标签: python tensorflow lstm word2vec

我正在使用word2vec作为输入来构建LSTM模型。我正在使用tensorflow框架。我已经完成了词的嵌入部分,但是我坚持使用LSTM部分。

这里的问题是我的句子长度不同,这意味着我必须填充或使用具有指定序列长度的dynamic_rnn。我正在和他们两个奋斗。

  1. 填充。 填充的混乱之处在于我何时进行填充。我的模特就像

    word_matrix = model.wv.syn0
    X = tf.placeholder(tf.int32,shape)
    数据= tf.placeholder(tf.float32,形状)
    数据= tf.nn.embedding_lookup(word_matrix,X)

然后,我正在将word_matrix的单词索引序列馈入X。我担心如果将零填充到馈入X的序列中,那么我会错误地继续馈入不必要的输入(在这种情况下,word_matrix [0])。

所以,我想知道0填充的正确方法是什么。如果您让我知道如何使用张量流实现它,那就太好了。

  1. dynamic_rnn 为此,我声明了一个包含所有句子长度的列表,并将其与X和y一起结尾。在这种情况下,我不能批量输入输入。然后,我遇到了此错误(在未知的TensorShape上未定义ValueError:as_list()。)在我看来,sequence_length参数仅接受列表? (不过,我的想法可能完全不正确。)

以下是我的代码。

X = tf.placeholder(tf.int32)
labels = tf.placeholder(tf.int32, [None, numClasses])
length = tf.placeholder(tf.int32)

data = tf.placeholder(tf.float32, [None, None, numDimensions])
data = tf.nn.embedding_lookup(word_matrix, X)

lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits, state_is_tuple=True)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.25)
initial_state=lstmCell.zero_state(batchSize, tf.float32)
value, _ = tf.nn.dynamic_rnn(lstmCell, data, sequence_length=length,
                             initial_state=initial_state, dtype=tf.float32)

我在这部分上很挣扎,以至于任何帮助将不胜感激。

谢谢。

1 个答案:

答案 0 :(得分:2)

Tensorflow不支持可变长度的Tensor。因此,当您声明张量时,list / numpy数组应具有统一的形状。

  1. 从第一部分开始,我了解到您已经能够在序列长度的最后一个时间步中填充零。理想情况应该是这样。查找批量大小为4,最大序列长度10和50个隐藏单位->

    的批次的方法如下

    [4,10,50]将是整个批次的大小,但是在内部,当您尝试可视化填充物时,它的形状可能像这样->

    `[[5+5pad,50],[10,50],[8+2pad,50],[9+1pad,50]`
    

    每个填充代表隐藏长度为50张量的序列长度为1。除零外,其他所有内容均已填充。查看UI LINK HEREthis question,以了解有关如何手动填充的更多信息。

  2. 您将使用动态rnn的确切原因是,您不想在填充序列上对其进行计算。 this one api将通过传递sequence_length参数来确保这一点。

    对于上面的示例,该参数将为:[5,10,8,9]对于上面的示例。您可以通过将每个批处理组件的非零实体相加来计算它。一种简单的计算方法是:

    data_mask = tf.cast(data, tf.bool)
    data_len = tf.reduce_sum(tf.cast(data_mask, tf.int32), axis=1)
    

    并将其通过tf.nn.dynamic_rnn api:

    tf.nn.dynamic_rnn(lstmCell, data, sequence_length=data_len, initial_state=initial_state)