单例数组array(<torch.utils.data.dataloader.dataloader object =“” at =“” 0x7f47667cda20 =“”>,dtype = object)无法视为有效集合

时间:2019-03-21 17:17:18

标签: machine-learning scikit-learn roc

我正在尝试使用Scikit-learn为我的4类分类集打印多类ROC曲线。我在Google Colab中使用它。我一直在关注this tutorial from Scikit。当我尝试执行此操作时-

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

我知道我需要对输出/测试数据进行二值化。所以,我做到了-

from google.colab import drive
drive.mount('/content/drive')
data = "/content/drive/My Drive/AMD_new"


import torch
import helper
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
import torchvision.models as models

from torchvision import datasets ,transforms

#Changning the transform of the data-
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                     transforms.RandomResizedCrop(224),
                                     # transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

transform_test = transforms.Compose([transforms.RandomHorizontalFlip(),
                                     transforms.RandomResizedCrop(224),
                                     # transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# choose the training and test datasets
train_data = datasets.ImageFolder(data+"/train", transform=transform_train)
test_data = datasets.ImageFolder(data+"/val", transform = transform_test)

dataloader_train = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
dataloader_test = torch.utils.data.DataLoader(test_data, batch_size=32, num_workers=2)

# Binarize the output
dataloader_test = label_binarize(dataloader_test, classes=[0, 1, 2, 3])
nb_classes = dataloader_test.shape[1]

但是有一个错误,我无法解决。

E: Package 'python-software-properties' has no installation candidate
··········
fuse: mountpoint is not empty
fuse: if you are sure this is safe, use the 'nonempty' mount option
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-11-f3ce2af1dde4> in <module>()
     56 dataloader_test = torch.utils.data.DataLoader(test_data, batch_size=32, num_workers=2)
     57 # Binarize the output
---> 58 dataloader_test = label_binarize(dataloader_test, classes=[0, 1, 2, 3])
     59 nb_classes = dataloader_test.shape[1]

/usr/local/lib/python3.6/dist-packages/sklearn/preprocessing/label.py in label_binarize(y, classes, neg_label, pos_label, sparse_output)
    579         # XXX Workaround that will be removed when list of list format is
    580         # dropped
--> 581         y = check_array(y, accept_sparse='csr', ensure_2d=False, dtype=None)
    582     else:
    583         if _num_samples(y) == 0:

/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)
    575     shape_repr = _shape_repr(array.shape)
    576     if ensure_min_samples > 0:
--> 577         n_samples = _num_samples(array)
    578         if n_samples < ensure_min_samples:
    579             raise ValueError("Found array with %d sample(s) (shape=%s) while a"

/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in _num_samples(x)
    140         if len(x.shape) == 0:
    141             raise TypeError("Singleton array %r cannot be considered"
--> 142                             " a valid collection." % x)
    143         # Check that shape is returning an integer or default to len
    144         # Dask dataframes may not return numeric shape[0] value

TypeError: Singleton array array(<torch.utils.data.dataloader.DataLoader object at 0x7f47667cda20>,
      dtype=object) cannot be considered a valid collection.

请帮助我克服这个问题。我已关注此帖子-Singleton array array(<function train at 0x7f3a311320d0>, dtype=object) cannot be considered a valid collection 谢谢。

0 个答案:

没有答案