我正在用pytorch在python中开发一个神经网络,以便对图像对进行分类。因此,我想返回2张图像和地面真实情况作为输出,但是每当尝试使用数据加载器时,都会收到错误消息“ Broken pipe”。
我希望它作为分类器工作,所以我以this link (pytorch classifier CIFAR10) 为例
这是我的代码:
###Defining class
class continuousImgDataset(Dataset):
def __init__(self, tabDataset, transform=None):
"""
Args:
tabDataset : Contains the imported data like : Item 1 = [im1,im2,ground-truth]
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.tabPairImg = tabDataset
self.transform = transform
def __len__(self):
return len(self.tabPairImg)
def __getitem__(self, idx):
img1 = self.tabPairImg[idx][0]
img2 = self.tabPairImg[idx][1]
label = self.tabPairImg[idx][2]
if self.transform:
img1 = self.transform(img1)
img2 = self.transform(img2)
return img1, img2,label
transform = transforms.Compose(
[transforms.Scale((32,32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
###Make the training and testing dataset
custom_dataset = continuousImgDataset(pairImg,transform)
train_dataset = continuousImgDataset(pairImg[: int(len(custom_dataset) * .90)],transform)
test_dataset = continuousImgDataset(pairImg[int(len(custom_dataset) * .90) : int(len(custom_dataset))],transform)
print(len(custom_dataset.tabPairImg)) #Output :22620
print(len(train_dataset.tabPairImg)) #Output : 20358
print(len(test_dataset.tabPairImg)) #Output : 2262
### Loaders
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=4,
shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('joint','disjoint')
### Image show and Error
"""From the official website but a little modified to have 2 imgs"""
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
dataiter = iter(trainloader) #ERROR HERE : Broken pipe
images1,images2, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images1))
imshow(torchvision.utils.make_grid(images2))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(2)))
我可能不完全了解数据加载器要对其进行迭代的期望。任何帮助,将不胜感激:)
编辑:这是完整的错误:
BrokenPipeError Traceback (most recent call last)
<ipython-input-58-314e85cc8fbb> in <module>
7
8 # get some random training images
----> 9 dataiter = iter(trainloader)
10 images1,images2, labels = dataiter.next()
11
C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __iter__(self)
191
192 def __iter__(self):
--> 193 return _DataLoaderIter(self)
194
195 def __len__(self):
C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __init__(self, loader)
467 # before it starts, and __del__ tries to join but will get:
468 # AssertionError: can only join a started process.
--> 469 w.start()
470 self.index_queues.append(index_queue)
471 self.workers.append(w)
C:\ProgramData\Anaconda3\lib\multiprocessing\process.py in start(self)
110 'daemonic processes are not allowed to have children'
111 _cleanup()
--> 112 self._popen = self._Popen(self)
113 self._sentinel = self._popen.sentinel
114 # Avoid a refcycle if the target function holds an indirect
C:\ProgramData\Anaconda3\lib\multiprocessing\context.py in _Popen(process_obj)
221 @staticmethod
222 def _Popen(process_obj):
--> 223 return _default_context.get_context().Process._Popen(process_obj)
224
225 class DefaultContext(BaseContext):
C:\ProgramData\Anaconda3\lib\multiprocessing\context.py in _Popen(process_obj)
320 def _Popen(process_obj):
321 from .popen_spawn_win32 import Popen
--> 322 return Popen(process_obj)
323
324 class SpawnContext(BaseContext):
C:\ProgramData\Anaconda3\lib\multiprocessing\popen_spawn_win32.py in __init__(self, process_obj)
87 try:
88 reduction.dump(prep_data, to_child)
---> 89 reduction.dump(process_obj, to_child)
90 finally:
91 set_spawning_popen(None)
C:\ProgramData\Anaconda3\lib\multiprocessing\reduction.py in dump(obj, file, protocol)
58 def dump(obj, file, protocol=None):
59 '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60 ForkingPickler(file, protocol).dump(obj)
61
62 #
BrokenPipeError: [Errno 32] Broken pipe