与GPU配合使用时,packed_pa​​dded_sequence会给出错误

时间:2019-01-25 02:42:17

标签: python deep-learning gpu pytorch

我正在尝试建立一个能够利用GPU的RNN,但是packed_pa​​dded_sequence给了我一个

RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor

这是我指导GPU计算的方式

parser = argparse.ArgumentParser(description='Trainer')
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
args = parser.parse_args()
args.device = None
if not args.disable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
    args.device = torch.device('cpu')

这是代码的相关部分。

def Tensor_length(track):
    """Finds the length of the non zero tensor"""
    return int(torch.nonzero(track).shape[0] / track.shape[1])
.
.
.
def forward(self, tracks, leptons):
        self.rnn.flatten_parameters()
        # list of event lengths
        n_tracks = torch.tensor([Tensor_length(tracks[i])
                                 for i in range(len(tracks))])
        sorted_n, indices = torch.sort(n_tracks, descending=True)
        sorted_tracks = tracks[indices].to(args.device)
        sorted_leptons = leptons[indices].to(args.device)
        # import pdb; pdb.set_trace()
        output, hidden = self.rnn(pack_padded_sequence(sorted_tracks,
                                                       lengths=sorted_n.cpu().numpy(),
                                                       batch_first=True)) # this gives the error

        combined_out = torch.cat((sorted_leptons, hidden[-1]), dim=1)
        out = self.fc(combined_out)  # add lepton data to the matrix
        out = self.softmax(out)
        return out, indices  # passing indices for reorganizing truth

我已经尝试了一切,从强制将sorted_n转换为长张量到将其列为列表,但我似乎都遇到了同样的错误。 之前我还没有在pytorch上使用过gpu,并且任何建议都将不胜感激。

谢谢!

1 个答案:

答案 0 :(得分:0)

我认为您正在使用 GPU ,并且可能正在Google Colab上使用。检查您的设备

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

您可以通过降低割炬版本来解决此错误,如果您使用的是colab,则以下命令将为您提供帮助:

!pip install torch==1.6.0 torchvision==0.7.0

一旦您将割炬降级,此填充长度错误就会消失。