Converting two NumPy datasets into a specific PyTorch dataset

Time: 2020-05-13 15:01:15

Tags: python neural-network pytorch

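I want to play around with a neural network that can recognize handwritten digits. I found several online that use PyTorch, but they all seem to download the data from the MNIST website in a specific format. My data, however, looks like this: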

with np.load('prediction-challenge-01-data.npz') as fh:
    data_x = fh['data_x']
    data_y = fh['data_y']

data_x is the training data and data_y holds the image labels. I would like these datasets to be in the same format as trainloader, as shown below:

trainset = datasets.MNIST('/data/mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

Here trainloader already has the training set data_x and the labels data_y combined.

Is there a way to do this?

Edit: the shapes of data_x and data_y:

In [1]:  data_x.shape
Out[2]: (20000, 1, 28, 28)

In [5]:  data_y.shape
Out[7]: (20000,)

1 answer:

Answer 0: (score: 1)

You can easily create your own dataset. Just inherit from torch.utils.data.Dataset and implement at least __getitem__ (plus __len__, so that a DataLoader can determine the size of the dataset).
Here is a quick and dirty example to get you started:

import numpy as np
import torch


class YourOwnDataset(torch.utils.data.Dataset):
    def __init__(self, input_file_path, transformations=None):
        super().__init__()
        self.path = input_file_path
        self.transforms = transformations

        with np.load(self.path) as fh:
            # I assume fh['data_x'] and fh['data_y'] are array-like
            # and small enough to keep in memory
            self.data = fh['data_x']
            self.labels = fh['data_y']

    # in __getitem__, we retrieve one item based on the input index
    def __getitem__(self, index):
        data = self.data[index]
        # based on the loss you chose and what you have in mind,
        # you can transform your label; here I assume the labels are
        # integers (like 1, 3, etc., as used for classification)
        label = int(self.labels[index])
        # convert/reshape your data into an image; as an assumption, the
        # (1, 28, 28) array is simply turned into a float tensor here
        img = torch.from_numpy(np.asarray(data, dtype=np.float32))
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

    def __len__(self):
        return len(self.data)
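
Once such a class exists, the dataset can be wrapped in a DataLoader the same way the MNIST trainset is in the question. The snippet below is only a minimal sketch: NpzDigits is a condensed, hypothetical stand-in for the class above so that the snippet runs on its own, the .npz file is a small synthetic one with random pixels (100 samples instead of 20000), and the scaling transform is an assumption rather than anything the question specifies.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class NpzDigits(Dataset):
    """Condensed, hypothetical stand-in for the dataset class above."""

    def __init__(self, path, transform=None):
        with np.load(path) as fh:
            self.data = fh["data_x"]
            self.labels = fh["data_y"]
        self.transform = transform

    def __getitem__(self, index):
        # Turn one (1, 28, 28) array into a float32 tensor.
        img = torch.from_numpy(np.asarray(self.data[index], dtype=np.float32))
        if self.transform is not None:
            img = self.transform(img)
        return img, int(self.labels[index])

    def __len__(self):
        return len(self.data)


# Synthetic stand-in for prediction-challenge-01-data.npz (random data).
rng = np.random.default_rng(0)
npz_path = os.path.join(tempfile.mkdtemp(), "fake-data.npz")
np.savez(
    npz_path,
    data_x=rng.integers(0, 256, size=(100, 1, 28, 28), dtype=np.uint8),
    data_y=rng.integers(0, 10, size=100),
)

# Assumed transform: scale pixel values from [0, 255] down to [0, 1].
trainset = NpzDigits(npz_path, transform=lambda t: t / 255.0)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Each batch is an (images, labels) pair, as with the MNIST loader.
images, labels = next(iter(trainloader))
print(images.shape, labels.dtype)
```

If no per-sample transform is needed, wrapping the two arrays as tensors in torch.utils.data.TensorDataset is a shorter alternative to writing a class.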