Broken pipe in PyTorch DataLoader

Time: 2018-06-22 15:19:06

Tags: python pytorch

I am trying to understand how DataLoader works.

This is how I use it:

# DATASET
import torch.utils.data as torch_data

# `utils` is the asker's own helper module (not shown); it provides skip_gram_tokenize.


class Word2VecDataset(torch_data.Dataset):

    def __init__(self, vocabulary):
        super(Word2VecDataset, self).__init__()
        self.data_list = []
        self.vocab = vocabulary
        # Build every training entry up front, when the dataset is created.
        self.generate_batch_list()

    def __getitem__(self, index):
        return self.data_list[index]

    def __len__(self):
        return len(self.data_list)

    def generate_batch_list(self):
        training_data = self.vocab.get_training_phrases()
        # Tokenize each query and collect the resulting skip-gram entries.
        for query in training_data.Query:
            query = utils.skip_gram_tokenize(vocab=self.vocab, sentence=query)
            for entry in query:
                self.data_list.append(entry)
        # Do the same for each response.
        for response in training_data.Response:
            response = utils.skip_gram_tokenize(vocab=self.vocab, sentence=response)
            for entry in response:
                self.data_list.append(entry)

This is the actual data loader part:

 dataset = Word2VecDataset(self.vocab)
 # Same arguments as before, passed by keyword: batch_size, then shuffle=True.
 data_loader = torch_data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

 print('Model Initialized')
 for epo in range(self.num_epochs):
     loss_val = None
     for i_batch, sample_batched in enumerate(data_loader): # This seems to be causing issues. For some reason this is the part that 'reboots' the whole model, making it print twice (more info under the code)
         loss_val = 0
         for data, target in sample_batched:
         ....

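Now, the strange thing is that the initialization phase (which you don't see here) prints "This is the detected GPU: xxx", and print('Model Initialized') is printed twice.

Finally, (pastebin) is the full console log (including the error).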

1 Answer:

Answer 0: (score: 0)

I had the same problem. I solved it by using if __name__ == '__main__': in the Python code, but I still have to solve the broken pipe in the Jupyter notebook...
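
For illustration, here is a minimal sketch of what that guard could look like. It assumes the duplicated output comes from DataLoader worker processes re-importing the main script, which is what happens when Python's multiprocessing uses the spawn start method (the default on Windows). PairDataset and count_batches are hypothetical stand-ins invented for this example, not the asker's code:

```python
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical stand-in for Word2VecDataset: holds (data, target) pairs."""

    def __init__(self, num_pairs):
        super().__init__()
        self.data_list = [(i, i + 1) for i in range(num_pairs)]

    def __getitem__(self, index):
        return self.data_list[index]

    def __len__(self):
        return len(self.data_list)


def count_batches(num_workers, batch_size=4, num_pairs=10):
    """Iterate once over the loader and return how many batches it produced."""
    dataset = PairDataset(num_pairs)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers)
    batches = 0
    # With num_workers > 0, the worker processes are started when iteration begins.
    for data, target in loader:
        batches += 1
    return batches


if __name__ == '__main__':
    # Only the process that runs the script directly enters this block.
    # A spawned worker that re-imports the script skips it, so the
    # initialization code and the print below are not executed again.
    print('Model Initialized')
    print(count_batches(num_workers=2))
```

Everything that should run only once (device detection, model setup, the training loop) goes under the guard, while class and function definitions stay at module level so the workers can import them. For the notebook case, one option that may help is num_workers=0, which loads the data in the main process and starts no workers at all, at the cost of parallel loading.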