如何在pytorch中批处理对话框数据集?

时间:2019-11-19 08:19:33

标签: pytorch

我想做一个面向任务的对话聊天机器人,用于预订餐厅,因为每个对话都有不同的顺序(例如,有些对话有5轮对话,即10句,而另一些对话有6轮对话,即12句。完全),我不知道如何批处理数据集。

能给我一些教程或github示例吗?

1 个答案:

答案 0 :(得分:1)

在Stackoverflow上有一些与此相关的问题。我喜欢here提供的说明/答案。 tldr版本将使用Packed SequenceThe answer I linked to提供了以下示例(从链接复制):

a = [torch.tensor([1,2,3]), torch.tensor([3,4])]
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)
>>>>
tensor([[ 1,  2,  3],
    [ 3,  4,  0]])
torch.nn.utils.rnn.pack_padded_sequence(b, batch_first=True, lengths=[3,2])
>>>>PackedSequence(data=tensor([ 1,  3,  2,  4,  3]), batch_sizes=tensor([ 2,  2,  1]))