在pytorch

时间:2019-05-21 12:51:05

标签: python python-3.x pytorch

通常,当我们在pytorch中加载数据时,我们会执行以下操作

for x, y in dataloaders:
    # Do something

但是,在名为MusicNet的数据集中,他们这样声明了自己的数据集和数据加载器

train_set = musicnet.MusicNet(root=root, train=True, download=True, window=window)#, pitch_shift=5, jitter=.1)
test_set = musicnet.MusicNet(root=root, train=False, window=window, epoch_size=50000)

train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=batch_size,**kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,**kwargs)

然后他们像这样加载数据

with train_set, test_set:
    for i, (x, y) in enumerate(train_loader):
        # Do something

问题1

我不明白为什么没有第with train_set, test_set行就无法使用代码。

问题2

此外,我该如何访问数据?

我尝试了

train_set.access(2560,0)

with train_set, test_set:
    x, y = train_set.access(2560,0)

他们要么给我一条错误消息,如

  

KeyError跟踪(最近一次通话最近)   ----> 1 train_set.access(2560,0)

     

/workspace/raven_data/AMT/MusicNet/pytorch_musicnet/musicnet.py在   如果self.mmap访问(self,rec_id,s,shift,jitter)106107:   -> 108 x = np.frombuffer(self.records [rec_id] [0] [ssz_float:int(s + scaleself.window)* sz_float],   dtype = np.float32).copy()109其他:110 fid,_ = self.records [rec_id]

     

KeyError:2560

或者给我一个空的xy

1 个答案:

答案 0 :(得分:1)

  

问题1

     

我不明白为什么没有第with train_set, test_set行就无法使用代码。

要使torch.utils.data.DataLoader自定义数据集设计一起使用,您必须创建一个数据集的类,该子类继承{{3} }(并实现特定功能),并将其传递给数据加载器,即使他们这样说:

  

所有其他数据集应将其子类化。所有子类都应覆盖__len__(提供数据集的大小)和__getitem__(支持从0到len(self)不包括在内的整数索引)。

这是发生在以下地方:

train_set = musicnet.MusicNet(root=root, train=True, download=True, window=window)#, pitch_shift=5, jitter=.1)

test_set = musicnet.MusicNet(root=root, train=False, window=window, epoch_size=50000)

train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=batch_size,**kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,**k

如果检查他们的torch.utils.data.Dataset,就会发现他们这样做了。

  

问题2

     

另外,我该如何访问数据?

有可能的方法:

要从数据集中仅批量获取,您可以执行以下操作:

batch = next(iter(train_loader))

要访问整个数据集(尤其是在您的示例中)

dataset = train_loader.dataset.records

.records是一部分,可能因数据集而异,我说.records,因为这是我在musicnet.MusicNet中发现的内容)