How do I replace a torch.tensor in place with a new padded tensor?

Asked: 2020-10-03 03:36:01

Tags: python-3.x pytorch tensor

I am trying to figure out how to overwrite a torch.tensor that lives inside a dict with a new torch.tensor that is somewhat longer due to padding.

# build the padding: 55 entries set to the padding id 100
zeros = torch.zeros(55).long()
zeros[zeros == 0] = 100  # change zeros to the padding value
temp_input = torch.cat([batch['input_ids'][0][0], zeros], dim=-1)  # concatenate
temp_input.shape  # torch.Size([567])
batch['input_ids'][0][0].shape  # torch.Size([512])
batch['input_ids'][0][0] = temp_input
# RuntimeError: The expanded size of the tensor (512) must match the existing
# size (567) at non-singleton dimension 0.  Target sizes: [512].  Tensor sizes: [567]

I am struggling to find a way to either expand tensors in place or overwrite them when their size changes.
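One observation that may help: a stacked tensor cannot hold rows of different lengths, so a `[512]` row cannot be swapped in place for a `[567]` one. A possible workaround, sketched below with illustrative names (`input_ids`, `pad_id`, `pad_len` are stand-ins, not from the original code), is to pad every row to the new length and rebuild the whole tensor with `torch.nn.functional.pad`:

```python
import torch
import torch.nn.functional as F

# Stand-in for batch['input_ids'], which has shape [4, 1, 512] in the question.
input_ids = torch.ones(4, 1, 512, dtype=torch.long)
pad_id, pad_len = 100, 55  # assumed padding id and padding length

# F.pad with pad=(0, pad_len) appends pad_len entries of pad_id
# on the right side of the last dimension, for every row at once.
padded = F.pad(input_ids, (0, pad_len), value=pad_id)
print(padded.shape)  # torch.Size([4, 1, 567])
```

Because this builds a new tensor instead of assigning into an indexed slot, it sidesteps the size-mismatch error entirely; the dict entry is then rebound with `batch['input_ids'] = padded`.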

The dict is emitted by a torch DataLoader and looks like this:

{'input_ids': tensor([[[  101,  3720,  2011,  ..., 25786,  2135,   102]],
 
         [[  101,  1017,  2233,  ...,     0,     0,     0]],
 
         [[  101,  1996,  2899,  ..., 14262, 20693,   102]],
 
         [[  101,  2197,  2305,  ...,  2000,  1996,   102]]]),
 'attn_mask': tensor([[[1, 1, 1,  ..., 1, 1, 1]],
 
         [[1, 1, 1,  ..., 0, 0, 0]],
 
         [[1, 1, 1,  ..., 1, 1, 1]],
 
         [[1, 1, 1,  ..., 1, 1, 1]]]),
 'cats': tensor([[-0.6410,  0.1481, -2.1568, -0.6976],
         [-0.4725,  0.1481, -2.1568,  0.7869],
         [-0.6410, -0.9842, -2.1568, -0.6976],
         [-0.6410, -0.9842, -2.1568, -0.6976]], grad_fn=<StackBackward>),
 'target': tensor([[1],
         [0],
         [1],
         [1]]),
 'idx': tensor([1391, 4000,  293,  830])}

0 Answers:

There are no answers yet.