How do I display the file paths of misclassified instances?

Time: 2020-08-30 14:42:20

Tags: python compiler-errors pytorch

I am trying to display the file paths of the instances that my model misclassifies (using PyTorch), and I wrote the following code:

incorrect_examples = []

model_gpu.eval()
for data in test_loader:
   images, labels = data
   images = images.to(device) 
   labels = labels.to(device)
   output = model_gpu(images)
   _, pred = torch.max(output,1)
   idxs_mask = ((pred == labels) == False).nonzero()
   incorrect_examples.append(images[idxs_mask].cpu().numpy())
   print(images[idxs_mask].cpu().numpy())
   print(test_loader.dataset.df.path[idxs_mask])

print(incorrect_examples)
print(len(incorrect_examples))

However, I get the following error:

KeyError                                  Traceback (most recent call last)
<ipython-input-45-02179a61953e> in <module>()
     11    incorrect_examples.append(images[idxs_mask].cpu().numpy())
     12    print(images[idxs_mask].cpu().numpy())
---> 13    print(test_loader.dataset.df.path[idxs_mask])
     14 
     15 print(incorrect_examples)

1 frames
/usr/local/lib/python3.6/dist-packages/pandas/core/series.py in __getitem__(self, key)
    869         key = com.apply_if_callable(key, self)
    870         try:
--> 871             result = self.index.get_value(self, key)
    872 
    873             if not is_scalar(result):

/usr/local/lib/python3.6/dist-packages/pandas/core/indexes/base.py in get_value(self, series, key)
   4403         k = self._convert_scalar_indexer(k, kind="getitem")
   4404         try:
-> 4405             return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
   4406         except KeyError as e1:
   4407             if len(self) > 0 and (self.holds_integer() or self.is_boolean()):

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_value()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_value()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index_class_helper.pxi in pandas._libs.index.Int64Engine._check_type()

KeyError: tensor([[ 1],
        [ 2],
        [ 5],
        [12],
        [14],
        [15]], device='cuda:0')

Can anyone tell me where I'm going wrong? Thanks in advance!
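One possible reading of the traceback, offered as a sketch and not a confirmed answer: `nonzero()` returns a 2-D tensor of shape (N, 1) that lives on the GPU, and `df.path[...]` treats that whole tensor as a single label to look up, which would explain the `KeyError`. The indices are also positions inside the current batch, not positions in the dataset, so a running batch offset is needed before they can select rows of `df`. The example below assumes the loader is built with `shuffle=False` and that the dataset returns samples in the same order as the rows of `df`. `ToyDataset`, `misclassified_paths`, the file names, and the `Identity` "model" are hypothetical stand-ins so the snippet runs on its own.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in: keeps a DataFrame `df` with `path` and `label` columns."""

    def __init__(self, df, features):
        self.df = df
        self.features = features

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        return self.features[i], int(self.df["label"].iloc[i])


def misclassified_paths(model, loader, device):
    """Collect file paths of misclassified samples (assumes shuffle=False)."""
    paths = []
    offset = 0  # dataset position of the first sample in the current batch
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            pred = model(images).argmax(dim=1)
            # Flatten the (N, 1) result and move it to the CPU as plain ints.
            wrong = torch.nonzero(pred != labels).flatten().cpu().tolist()
            # Convert within-batch positions to dataset-level row positions.
            rows = [offset + i for i in wrong]
            # .iloc selects by position, independent of the index labels.
            paths.extend(loader.dataset.df["path"].iloc[rows].tolist())
            offset += labels.size(0)
    return paths


# Toy data: the "model" is Identity, so each feature row doubles as the logits.
df = pd.DataFrame({
    "path": ["a.png", "b.png", "c.png", "d.png", "e.png"],
    "label": [0, 1, 0, 1, 1],
})
features = torch.tensor([[2.0, 1.0], [3.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 2.0]])
loader = DataLoader(ToyDataset(df, features), batch_size=2, shuffle=False)

paths = misclassified_paths(torch.nn.Identity(), loader, torch.device("cpu"))
print(paths)  # ['b.png', 'e.png']
```

If the loader shuffles, the offset no longer matches the row order; in that case it is probably simpler to have the dataset's `__getitem__` also return the path (or the row index) and collect it directly from each batch.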

0 answers:

No answers