给定大小为n x 2A x B x C
的输入张量,如何将其拆分为两个张量,每个张量为n x A x B x C
?基本上,n
是批量大小。
答案 0 :(得分:1)
您可以使用torch.split
:
torch.split(input_tensor, split_size_or_sections=A, dim=1)
答案 1 :(得分:0)
我认为你可以这样做:
tensor_a = torch.Tensor(n, 2A, B,C)
-- Initialize tensor_a with the data
tensor_b = torch.Tensor(n, A, B, C)
tensor_b = tensor_a[{{},1,{},{}}]
tensor_c = torch.Tensor(n, A, B, C)
tensor_c = tensor_a[{{},2,{},{}}]