我有一个也许很小的问题,但现在我被困了很长时间。希望有人可以帮助我。我目前在一个Kddcup99数据集上,我想通过DeepLearning(CNN网络)进行训练
我有一个“数据集”类,其中包括熊猫数据框。因此,我分解为普通数据并验证了数据集。到目前为止,没有问题。 我将其加载到一个Numpy向量中,将其焊到Tensor,然后将其定向到DataLoader。
数据集类具有以下两个重要的类,可以进行迭代:
def __len__(self):
return len(self.val_df)
def __getitem__(self, index):
img, target = self.val_df[index][:-1], self.val_df[index][-1]
return img, target, index
该类中没有DataLoader字符串:
test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)
在我的Trainer类中,我有一个for循环,该循环应遍历Dataloader:
with torch.no_grad():
for data in dataloader:
inputs, labels, idx = data
inputs = inputs.to(self.device)
但是不会。我无法访问标签,索引等。
我的问题是现在:为什么? 如何通过数据加载器从给定的数据集中访问标签,索引?
谢谢大家的帮助! 非常感谢。
答案 0 :(得分:0)
DataLoader
的第一个参数是您要从中加载数据的数据集,通常是Dataset
,但并不限于Dataset
的任何实例。只要可以定义长度(__len__
)并且可以索引(__getitem__
允许),就可以接受。
您正在将datat.val_df
传递到DataLoader
,这大概是一个NumPy数组。 NumPy数组具有长度并且可以被索引,因此可以在DataLoader
中使用。由于您直接传递该数组,因此永远不会调用数据集的__getitem__
,但会对数组本身进行索引,因此每个项目都只是data.val_df[index]
。
您必须使用数据集本身(DataLoader
来代替datat
使用基础数据:
test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)