Pytorch相当于`tf.reverse_sequence`?

时间:2019-04-29 14:01:28

标签: python tensorflow pytorch

我想对填充序列进行反向LSTM,这需要反转输入序列而无需填充。

对于这样的批次(其中_代表填充):

a b c _ _ _
d e f g _ _
h i j k l m

如果想得到:

c b a _ _ _
g f e d _ _
m l k j i h

TensorFlow具有函数tf.reverse_sequence,该函数获取批次中序列的输入张量和长度,然后返回反转的批次。在Pytorch中有一种简单的方法吗?

1 个答案:

答案 0 :(得分:2)

很遗憾,although it has been requested还没有直接的等效项。

我也查看了整个PackedSequence对象,但是没有定义.flip()操作。假设您已经按照建议的那样拥有必要的数据来提供长度,则可以使用以下功能实现它:

def flipBatch(data, lengths):
    assert data.shape[0] == len(lengths), "Dimension Mismatch!"
    for i in range(data.shape[0]):
        data[i,:lengths[i]] = data[i,:lengths[i]].flip(dims=[0])

    return data

不幸的是,这仅在序列为二维(使用batch_size x sequence)时才有效,但是您可以轻松地将其扩展为特定的输入要求。这已经或多或少涵盖了上面链接中的提案,但是我将其更新为今天的标准。