获取pytorch数据集的子集

时间:2017-11-22 10:22:37

标签: python machine-learning neural-network torch pytorch

我有一个网络,我想在一些数据集上训练(例如,说CIFAR10)。我可以通过

创建数据加载器对象
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

我的问题如下:假设我想进行几次不同的训练迭代。让我们说我首先想要在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络,依此类推。为此,我需要能够访问这些图像。不幸的是,trainset似乎不允许此类访问。也就是说,尝试执行trainset[:1000]或更多trainset[mask]会导致错误。

我可以改为

trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]

然后

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)

然而,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改trainset.train_data所以我需要重新定义trainset)。有没有办法避免它?

理想情况下,我想要一些东西"相当于"到

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                              shuffle=True, num_workers=2)

2 个答案:

答案 0 :(得分:11)

您可以为数据集加载器定义自定义采样器,避免重新创建数据集(只需为每个不同的采样创建一个新的加载器)。

Option Explicit

Sub ProcessAllSlides()

Dim sld As Slide
Dim Shp As Shape
Dim oCht As Chart
Dim i As Long
Dim ChartIndex As Long

' set the Active Slide
Set sld = Application.ActiveWindow.View.Slide

ChartIndex = 1

' --- loop through the Slide shapes and search for the Shape of type chart
For i = 1 To sld.Shapes.Count
    If sld.Shapes(i).HasChart = msoTrue Then  ' if current shape is a chart
        Set Shp = sld.Shapes(i)
        Set oCht = Shp.Chart

        If ChartIndex = 1 Then ' first chart
            SetChartSizeAndPosition Shp, 30, 120, 320, 240
            ChartIndex = ChartIndex + 1

        ElseIf ChartIndex = 2 Then ' second chart
            SetChartSizeAndPosition Shp, 370, 120, 320, 240
        End If

        With oCht.PlotArea
            ' Edit these values as needed
            ' Change the following lines to e.g. Msgbox .Left etc
            ' to get the values of the chart you want to match others TO
            .Left = 0
            .Top = 0
            .Height = 220
            .Width = 300
        End With

        Set oCht = Nothing
        Set Shp = Nothing
    End If
Next i

End Sub

PS:您可以在此处找到更多信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

答案 1 :(得分:8)

torch.utils.data.Subset更容易,支持shuffle,并且不需要编写自己的采样器:

import torchvision
import torch

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=None)

evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)

trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                            shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                            shuffle=True, num_workers=2)