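I am given a 2d tensor in of shape a x b. In my example a = 9, and A1, A2, ..., C2 denote its b-dimensional row vectors.
In addition, I have an array lengths with sum(lengths) = a, where every entry is a positive integer.
I want to obtain a 3d output tensor out in which the first lengths[0] entries of in form the first row, the next lengths[1] entries of in form the second row, and so on. That is, the output tensor should have shape len(lengths) x max(lengths) x b and be padded with zeros, where each padding entry is a b-dimensional zero vector.
Since this is part of a neural network trained with backpropagation, all operations used must be differentiable. How can this be done with PyTorch (ideally with good performance)?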
Answer 0 (score: 1)
You can use the function below. It is differentiable, so it works with backpropagation.
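import torch

def sequence_to_padding(x, length):
    # Output shape is (num_sequences, max_length, *trailing dims of x), so x may have any number of trailing dims.
    ret_tensor = torch.zeros((length.shape[0], int(torch.max(length))) + tuple(x.shape[1:]), dtype=x.dtype, device=x.device)
    cum_len = 0
    for i, l in enumerate(length.tolist()):
        # Copy the next l rows of x into row i; the remaining positions stay zero.
        ret_tensor[i, :l] = x[cum_len:cum_len + l]
        cum_len += l
    return ret_tensor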
Example:
in_vector = torch.rand((9,1))
#tensor([[0.3545],
# [0.5443],
# [0.7550],
# [0.9624],
# [0.9250],
# [0.8035],
# [0.6877],
# [0.4186],
# [0.4199]])
lengths = torch.tensor([3, 4, 2])
sequence_to_padding(in_vector, lengths)
#tensor([[[0.3545],
# [0.5443],
# [0.7550],
# [0.0000]],
#
# [[0.9624],
# [0.9250],
# [0.8035],
# [0.6877]],
#
# [[0.4186],
# [0.4199],
# [0.0000],
# [0.0000]]])
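The differentiability claim can be checked directly. The following is a minimal sketch, not part of the original answer: it restates the function so the block runs on its own, and it assumes that summing the output is an acceptable stand-in for a real loss. Each input row is copied into the output exactly once, so the gradient of the sum with respect to the input should be all ones.

```python
import torch

def sequence_to_padding(x, length):
    # Same logic as the function above, restated so this block is self-contained.
    max_len = int(length.max())
    ret = x.new_zeros((length.shape[0], max_len) + tuple(x.shape[1:]))
    cum_len = 0
    for i, l in enumerate(length.tolist()):
        ret[i, :l] = x[cum_len:cum_len + l]
        cum_len += l
    return ret

# Hypothetical check: a leaf tensor that requires gradients.
x = torch.rand(9, 1, requires_grad=True)
lengths = torch.tensor([3, 4, 2])

out = sequence_to_padding(x, lengths)
out.sum().backward()  # the sum is only a placeholder for a real loss

grad = x.grad  # every input row contributes once, so this is all ones
print(out.shape, grad.shape)
```

The slice assignments are recorded by autograd, so gradients reach x even though the output buffer itself is created without requires_grad.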
Answer 1 (score: 1)
Here is my implementation using torch.nn.utils.rnn.pad_sequence():
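import torch

in_tensor = torch.rand((9, 3))
print(in_tensor)
print(36 * '=')
lengths = torch.tensor([3, 4, 2])
cum_len = 0
y = []
# Slice in_tensor into consecutive chunks of the given lengths.
for val in lengths.tolist():
    y.append(in_tensor[cum_len:cum_len + val])
    cum_len += val
# Pad every chunk with zeros up to max(lengths) and stack the chunks along a new batch dimension.
print(torch.nn.utils.rnn.pad_sequence(y, batch_first=True))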
Output:
# in_tensor of shape (9 x 3)
tensor([[0.9169, 0.3549, 0.6211],
[0.4832, 0.5475, 0.8862],
[0.8708, 0.5462, 0.9374],
[0.4605, 0.1167, 0.5842],
[0.1670, 0.2862, 0.0378],
[0.2438, 0.5742, 0.4907],
[0.1045, 0.5294, 0.5262],
[0.0805, 0.2065, 0.2080],
[0.6417, 0.4479, 0.0688]])
====================================
# out tensor of shape (len(lengths) x max(lengths) x b), in this case b is 3
tensor([[[0.9169, 0.3549, 0.6211],
[0.4832, 0.5475, 0.8862],
[0.8708, 0.5462, 0.9374],
[0.0000, 0.0000, 0.0000]],
[[0.4605, 0.1167, 0.5842],
[0.1670, 0.2862, 0.0378],
[0.2438, 0.5742, 0.4907],
[0.1045, 0.5294, 0.5262]],
[[0.0805, 0.2065, 0.2080],
[0.6417, 0.4479, 0.0688],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]]])
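Both answers above loop over the sequences in Python. Since the question also asks about performance, here is a hedged sketch of a loop-free alternative that neither answer proposes: it builds a boolean mask of valid positions and writes all rows with a single masked assignment. The name pad_by_mask is hypothetical, and whether this is actually faster depends on the tensor sizes and the device, so measure before relying on it.

```python
import torch

def pad_by_mask(x, lengths):
    # Hypothetical helper: x has shape (sum(lengths), ...), lengths is a 1d integer tensor.
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=x.device)
    # mask[i, j] is True when position j lies inside sequence i.
    mask = positions.unsqueeze(0) < lengths.to(x.device).unsqueeze(1)
    out = x.new_zeros((lengths.shape[0], max_len) + tuple(x.shape[1:]))
    # True entries are visited in row-major order, which matches the row order of x.
    out[mask] = x
    return out

x_demo = torch.arange(27.0).reshape(9, 3).requires_grad_()
lengths_demo = torch.tensor([3, 4, 2])
out_demo = pad_by_mask(x_demo, lengths_demo)
out_demo.sum().backward()  # gradients flow back through the masked assignment
print(out_demo.shape)
```

The masked assignment is tracked by autograd in the same way as the slice assignments in the loop versions, so this variant can also be used inside a network trained with backpropagation.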