如何在不使用python索引的情况下切片割炬张量

时间:2019-09-15 13:55:26

标签: pytorch

我下面的pytorch代码不断收到jit跟踪器警告(在pytorch 1.1.0环境中),抱怨“ Pytorch 1.0跟踪器警告:将张量转换为Python索引可能...”

是否可以在不使用python索引的情况下实现下面标记为(A)的代码行?

N,C,H,W = input.size()
Cout=4*C
Hout=H//2
Wout=W//2
downsampled=torch.zeros([N,Cout,Hout,Wout], dtype= torch.FloatTensor)
downsampled[:,1:Cout:4,:,:]=input[:,:,0::2,1::2] ---- (A)

1 个答案:

答案 0 :(得分:0)

我确认jit跟踪器不再抱怨Pytorch 1.2中的python索引(如Umang Gupta所说)。

顺便说一句,我想出了一个没有切片的实现(但仍使用索引),如下所示:

import torch

input=torch.arange(100)
input=input.view(10,10)
input=input[None, None, ...].expand(2,3,10,10) #torch.Size([2,3,10,10])

N,C,H,W=input.size()
Cout=4*C
Hout=H//2
Wout=W//2

downsampled=torch.zeros([N,Cout,Hout,Wout],dtype=torch.int8) #torch.Size([2,12,5,5])

dim2_idx=torch.tensor([k for k in range(0,H,2)])
dim3_idx=torch.tensor([k for k in range(1,W,2)])
sliced_input=input.index_select(2,dim2_idx).index_select(3,dim3_idx) #torch.Size([2,3,5,5])

#downsampled.index_select(1,torch.tensor([k for k in range(1,Cout,4)]))=temp <---Error: Can't assign to function call

for idx in range(1,Cout,4):
    downsampled[:,idx,:,:]=sliced_input[:,idx//4,:,:]