Creating a dataset with PyTorch DataLoader

Date: 2020-08-05 15:49:38

Tags: python pytorch

I have a folder with subfolders (one per class), and each subfolder contains images.


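My goal is to create a dataset (training + test set) to train my model with a PyTorch ResNet. I get an error that I don't know how to fix, because I don't really understand how DataLoader works, so I tried the following: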

My folder structure looks like this:

data
  |_ classe1
        |_ image1
        |_ image2
  |_ classe2
        |_ ...

But when I try to run the model, I get an error. Here is my code:

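import os
import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler
from torchvision import datasets

dataset = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['data']}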

batch_size = 32
validation_split = .3
shuffle_dataset = True
random_seed = 42

# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)

dataloaders_dict = {'train': train_loader, 'val': validation_loader}

Any suggestions? Can you spot any mistakes?

1 answer:

Answer 0 (score: 0)

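The problem most likely comes from your first line, where your dataset is actually a dictionary containing a single element (a PyTorch dataset). This would be better: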

x = 'data'
dataset = datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])

I'm assuming data_transforms['data'] is a transform of the expected type (detailed here).

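When PyTorch tries to fetch a sample from this "dataset" (really a dictionary with the single key 'data'), it indexes it with an integer such as 0, which raises a KeyError. Note also that len(dataset) is 1, so dataset_size is 1 and the split gives one index for training and none for validation.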
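The following pure-Python sketch makes the failure concrete. It swaps the ImageFolder for a hypothetical list of three samples and reproduces both symptoms of wrapping the dataset in a dict: the length is 1, and integer indexing raises KeyError:

```python
import math

# Hypothetical stand-in for the ImageFolder object: any sequence of samples.
image_folder = ["img_a", "img_b", "img_c"]

# What the question's dict comprehension actually builds: {'data': <dataset>}
dataset = {x: image_folder for x in ["data"]}

# The split code therefore sees one element, not three images.
dataset_size = len(dataset)                  # 1
split = int(math.floor(0.3 * dataset_size))  # 0 -> train indices [0], val indices []

# DataLoader fetches samples as dataset[idx], with integer indices from the sampler.
error_key = None
try:
    dataset[0]
except KeyError as err:
    error_key = err.args[0]
    print("KeyError:", err)  # KeyError: 0
```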

By the way, PyTorch provides a torch.utils.data.random_split function, so you don't have to do the train/test split by hand. You may want to look into it.
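Here is a minimal random_split sketch. A hypothetical TensorDataset of 20 dummy samples stands in for the ImageFolder. The 70/30 ratio and seed 42 mirror the question's validation_split and random_seed:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for the ImageFolder: 20 samples with dummy labels.
dataset = TensorDataset(torch.arange(20).float().unsqueeze(1),
                        torch.zeros(20, dtype=torch.long))

val_size = int(0.3 * len(dataset))     # 6
train_size = len(dataset) - val_size   # 14
train_set, val_set = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42),  # reproducible split
)
print(len(train_set), len(val_set))  # 14 6

# Each subset is an ordinary dataset, so no sampler is needed.
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4)
```

Both subsets wrap the same underlying dataset, so they also share its transform. Using separate augmentations for training and validation would require two dataset objects.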