PyTorch DataLoader error when num_workers is greater than 0

Asked: 2019-03-31 02:44:55

Tags: python parallel-processing pytorch

I cannot use a DataLoader with num_workers set to anything greater than zero.

My torch version is 1.0.0.

My code:

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class InputData(Dataset):
    '''read data'''
    def __init__(self, train_serial_nums,
                im_dir = TRAIN_PATH+'/img/', mask_dir = TRAIN_PATH+'/label/'):
        self.train_serial_nums = train_serial_nums
        self.im_dir = im_dir
        self.mask_dir = mask_dir

    def open_as_gray(self,serial_num):
        # open an image and convert to gray
        serial_num = str(serial_num)
        im_name = "img_"+serial_num+".jpg"
        mask_name = "label_"+serial_num+".png"

        im = np.asarray(Image.open(os.path.join(self.im_dir, im_name)))
        im = im/255

        _mask = Image.open(os.path.join(self.mask_dir, mask_name)).convert("L")
        _mask = np.expand_dims(_mask, axis=-1)
        mask_0 = np.where(np.logical_or(_mask == 60, _mask == 180), 1, 0)
        mask_1 = np.where(np.logical_or(_mask == 120, _mask == 180), 1, 0)
        mask = np.concatenate([mask_0, mask_1], axis=-1)

        return im.transpose(2,0,1), mask.transpose(2,0,1), im_name, mask_name

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.train_serial_nums)

    def __getitem__(self, index):
        'Generates one sample of data'

        trainId = self.train_serial_nums[index]
        # Load data and get label
        trainImg, trainMask,_,_ = self.open_as_gray(trainId)
        return trainImg, trainMask
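
A sketch of one possible failure mode (an assumption, not something the post confirms): Image.open can return a single-channel image, in which case np.asarray yields a 2-D array and im.transpose(2,0,1) raises a ValueError. Inside a worker process that exception surfaces only as the opaque RuntimeError shown below. A defensive version of the image loading might look like this; load_rgb is a hypothetical helper, not code from the post:

import numpy as np
from PIL import Image

def load_rgb(path):                    # hypothetical helper, not from the post
    im = np.asarray(Image.open(path))
    if im.ndim == 2:                   # grayscale image: no channel axis
        im = np.stack([im] * 3, axis=-1)
    return im.transpose(2, 0, 1)       # always (C, H, W)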

I use DataLoader for parallel data loading.

train_serial_nums = np.arange(40000, 50000)
train_set = InputData(train_serial_nums)
train_generator = DataLoader(
    train_set, batch_size=32,
    num_workers=1)
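
Because the multiprocessing error hides the original traceback, a quick check (a debugging sketch, not part of the original post) is to call __getitem__ directly in the main process, where the real exception prints in full:

# Iterate the Dataset without any workers; whatever crashes the worker
# should crash here too, with a complete traceback.
for i in range(len(train_set)):
    try:
        train_set[i]                   # runs open_as_gray in this process
    except Exception as exc:
        print('sample', i, 'failed:', repr(exc))
        break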

Error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<timed exec> in <module>

/data/anonym2/anaconda3/envs/tensorflow/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    629         while True:
    630             assert (not self.shutdown and self.batches_outstanding > 0)
--> 631             idx, batch = self._get_batch()
    632             self.batches_outstanding -= 1
    633             if idx != self.rcvd_idx:

/data/anonym2/anaconda3/envs/tensorflow/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _get_batch(self)
    608             # need to call `.task_done()` because we don't use `.join()`.
    609         else:
--> 610             return self.data_queue.get()
    611 
    612     def __next__(self):

/data/anonym2/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/queues.py in get(self, block, timeout)
     92         if block and timeout is None:
     93             with self._rlock:
---> 94                 res = self._recv_bytes()
     95             self._sem.release()
     96         else:

/data/anonym2/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/connection.py in recv_bytes(self, maxlength)
    214         if maxlength is not None and maxlength < 0:
    215             raise ValueError("negative maxlength")
--> 216         buf = self._recv_bytes(maxlength)
    217         if buf is None:
    218             self._bad_message_length()

/data/anonym2/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/connection.py in _recv_bytes(self, maxsize)
    405 
    406     def _recv_bytes(self, maxsize=None):
--> 407         buf = self._recv(4)
    408         size, = struct.unpack("!i", buf.getvalue())
    409         if maxsize is not None and size > maxsize:

/data/anonym2/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/connection.py in _recv(self, size, read)
    377         remaining = size
    378         while remaining > 0:
--> 379             chunk = read(handle, remaining)
    380             n = len(chunk)
    381             if n == 0:

/data/anonym2/anaconda3/envs/tensorflow/lib/python3.6/site-packages/torch/utils/data/dataloader.py in handler(signum, frame)
    272         # This following call uses `waitid` with WNOHANG from C side. Therefore,
    273         # Python can still get and update the process status successfully.
--> 274         _error_if_any_worker_fails()
    275         if previous_handler is not None:
    276             previous_handler(signum, frame)

RuntimeError: DataLoader worker (pid 5235) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. 
Rerunning with num_workers=0 may give better error trace.

If I set num_workers = 0, the code runs fine, but it is very slow.
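
Since num_workers=0 is too slow for a full pass, a middle ground (a sketch against the Dataset above, not from the original post) is to keep the workers but have __getitem__ report which sample it was processing before re-raising, so the message reaches the console even though the traceback itself is lost:

    def __getitem__(self, index):
        'Generates one sample of data'
        try:
            trainId = self.train_serial_nums[index]
            trainImg, trainMask, _, _ = self.open_as_gray(trainId)
            return trainImg, trainMask
        except Exception:
            import traceback
            print('failed at index', index)   # printed from the worker process
            traceback.print_exc()
            raise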

0 Answers:

No answers