Python数据集类+ PyTorch数据加载器:卡在__getitem__上,如何在测试期间获取索引,标签等?

时间:2020-05-18 11:40:55

标签: python machine-learning dataset pytorch dataloader

我有一个也许很小的问题,但现在我被困了很长时间。希望有人可以帮助我。我目前在一个Kddcup99数据集上,我想通过DeepLearning(CNN网络)进行训练

我有一个“数据集”类,其中包括熊猫数据框。因此,我分解为普通数据并验证了数据集。到目前为止,没有问题。 我将其加载到一个Numpy向量中,将其焊到Tensor,然后将其定向到DataLoader。

数据集类具有以下两个重要的类,可以进行迭代:

def __len__(self):
        return len(self.val_df)

def __getitem__(self, index):        
        img, target = self.val_df[index][:-1], self.val_df[index][-1]
        return img, target, index

该类中没有DataLoader字符串:

test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)

在我的Trainer类中,我有一个for循环,该循环应遍历Dataloader:

with torch.no_grad():
            for data in dataloader:
                inputs, labels, idx = data
                inputs = inputs.to(self.device)

但是不会。我无法访问标签,索引等。

我的问题是现在:为什么? 如何通过数据加载器从给定的数据集中访问标签,索引?

谢谢大家的帮助! 非常感谢。

1 个答案:

答案 0 :(得分:0)

DataLoader的第一个参数是您要从中加载数据的数据集,通常是Dataset,但并不限于Dataset的任何实例。只要可以定义长度(__len__)并且可以索引(__getitem__允许),就可以接受。

您正在将datat.val_df传递到DataLoader,这大概是一个NumPy数组。 NumPy数组具有长度并且可以被索引,因此可以在DataLoader中使用。由于您直接传递该数组,因此永远不会调用数据集的__getitem__,但会对数组本身进行索引,因此每个项目都只是data.val_df[index]

您必须使用数据集本身(DataLoader来代替datat使用基础数据:

test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)