LSTM遵循均值池(TensorFlow)

时间:2017-09-04 07:39:40

标签: tensorflow lstm pooling rnn

我知道在LSTM Followed by Mean Pooling有一个类似的话题,但那是关于Keras的,我在纯TensorFlow中工作。

我有一个LSTM网络,其中重现由:

处理
outputs, final_state = tf.nn.dynamic_rnn(cell,
                                         embed,
                                         sequence_length=seq_lengths,
                                         initial_state=initial_state)

我为每个样本传递正确的序列长度(用零填充)。在任何情况下,输出都包含不相关的输出,因为根据序列长度,某些样本产生的输出比其他样本长。

现在我正通过以下方法提取最后的相关输出:

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.reduce_mean(tf.gather_nd(data, indices), axis=0)

我将sequence_length - 1作为索引传递。在参考最后一个主题时,我想选择所有相关的输出,然后选择平均合并,而不是最后一个。

现在,我尝试将嵌套列表作为extract_axis_1的替代值传递,但tf.stack不接受此操作。

有解决方案吗?

1 个答案:

答案 0 :(得分:0)

您可以利用weight功能的tf.contrib.seq2seq.sequence_loss参数。

来自文档:

  

weights:形状张力[batch_sizesequence_length]和dtype float。权重构成序列中每个预测的权重。使用weights作为屏蔽时,将所有有效时间步长设置为1,将所有填充时间步长设置为0,例如由tf.sequence_mask返回的掩码。

您需要计算二进制掩码,以区分有效输出和无效输出。然后你可以将这个掩码提供给损失函数的weights参数(可能,你会想要使用像这样的丢失);在计算损失时,函数不会考虑具有0权重的输出。

如果您不能使用序列丢失,您可以手动执行完全相同的操作。您计算二进制掩码,然后将输出乘以此掩码,并将其作为输入提供给完全连接的层。