How to make a custom PyTorch dataset like the Torchvision datasets?

Date: 2020-07-27 23:50:29

Tags: python pytorch torch

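I'm new to PyTorch and am trying to reuse the Fashion MNIST CNN (from deeplizard) to classify my time-series data. I'm finding it hard to understand how datasets are structured: following this official tutorial and this SO question as best I can, what I end up with is too simple. I think this is because I don't understand OOP very well. The dataset I made trains fine on the CNN, but I'm stuck when I try to analyse the results with the tutorial's code.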

So I create my dataset from two torch tensors called features [4050, 1, 150, 6] and targets [4050]:

from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(features, targets) # create your dataset
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=False) # create your dataloader
print(train_dataset.__dict__.keys()) # list the attributes
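A minimal, self-contained sketch of the setup above, assuming random stand-in tensors with the shapes from the question and 10 integer class labels (both hypothetical; the real data isn't shown). It shows that each batch is a (features, targets) pair sliced along the first dimension, and that the dataset's only instance attribute is tensors:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical stand-in data with the shapes described in the question:
# 4050 samples, each a 1 x 150 x 6 "image" of time-series features.
features = torch.randn(4050, 1, 150, 6)
targets = torch.randint(0, 10, (4050,))  # assumed: 10 integer class labels

train_dataset = TensorDataset(features, targets)
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=False)

# Each batch holds one slice of every wrapped tensor, in the order they were passed.
feature_batch, target_batch = next(iter(train_dataloader))
print(feature_batch.shape)  # torch.Size([50, 1, 150, 6])
print(target_batch.shape)   # torch.Size([50])

# The only instance attribute is the tuple of wrapped tensors.
print(list(vars(train_dataset).keys()))  # ['tensors']
```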

I get this printout when I check the attributes:

dict_keys(['tensors'])

But in the Fashion MNIST tutorial, they access the data like this:

import torch
import torchvision
import torchvision.transforms as transforms

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)
print(train_set.__dict__.keys()) # list the attributes

and checking the attributes then gives this printout:

dict_keys(['root', 'transform', 'target_transform', 'transforms', 'train', 'data', 'targets'])
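For comparison, a custom Dataset can expose the same kind of attributes as the torchvision one simply by storing them in __init__. The sketch below is hypothetical (the class name TimeSeriesDataset and the random data are illustrative, not from the tutorial) and only mimics the data/targets/transform layout, not torchvision's download or file-loading logic:

```python
import torch
from torch.utils.data import Dataset


class TimeSeriesDataset(Dataset):
    """Hypothetical Dataset exposing .data and .targets like torchvision's FashionMNIST."""

    def __init__(self, data, targets, transform=None, target_transform=None):
        assert data.size(0) == targets.size(0), "data and targets must have the same length"
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Number of samples = size of the first dimension.
        return self.data.size(0)

    def __getitem__(self, index):
        sample, target = self.data[index], self.targets[index]
        # Optional per-sample transforms, applied lazily as in torchvision datasets.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target


# Hypothetical data with the question's shapes.
features = torch.randn(4050, 1, 150, 6)
targets = torch.randint(0, 10, (4050,))

train_dataset = TimeSeriesDataset(features, targets)
print(list(vars(train_dataset).keys()))  # ['data', 'targets', 'transform', 'target_transform']
print(train_dataset.targets.shape)       # torch.Size([4050])
```

With a dataset built this way, train_dataset.targets exists, so the tutorial's analytics code could use it unchanged.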

My dataset trains fine, but when I get to the analytics part later in the tutorial, the code expects to access parts of the dataset and I get an error:

# Analytics
prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
train_preds = get_all_preds(network, prediction_loader)
preds_correct = train_preds.argmax(dim=1).eq(train_dataset.targets).sum().item()

print('total correct:', preds_correct)
print('accuracy:', preds_correct / len(train_dataset))

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-73-daa87335a92a> in <module>
      4 prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
      5 train_preds = get_all_preds(network, prediction_loader)
----> 6 preds_correct = train_preds.argmax(dim=1).eq(train_dataset.targets).sum().item()
      7 
      8 print('total correct:', preds_correct)

AttributeError: 'TensorDataset' object has no attribute 'targets'
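The analytics snippet calls get_all_preds, a helper defined in the deeplizard tutorial but not shown here. A minimal sketch of what such a helper might look like, assuming the loader yields (inputs, labels) pairs and the model returns one row of class scores per sample; the network and data below are hypothetical placeholders. Because the predictions are later compared element-wise with the targets, the prediction loader must not shuffle:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no gradients are needed for inference
def get_all_preds(model, loader):
    """Hypothetical helper: run the model over every batch and stack the outputs."""
    model.eval()  # switch off dropout / batch-norm training behaviour
    all_preds = []
    for inputs, _labels in loader:
        all_preds.append(model(inputs))
    return torch.cat(all_preds, dim=0)


# Hypothetical tiny network and data, only to exercise the helper.
network = nn.Sequential(nn.Flatten(), nn.Linear(150 * 6, 10))
features = torch.randn(200, 1, 150, 6)
targets = torch.randint(0, 10, (200,))
prediction_loader = DataLoader(TensorDataset(features, targets), batch_size=50)

train_preds = get_all_preds(network, prediction_loader)
print(train_preds.shape)  # torch.Size([200, 10])
```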

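Can anyone tell me what's going on here? Is this something I need to change in how I build the dataset, or can I somehow rewrite the analytics code to access the right part of the dataset?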

1 Answer

Answer 0 (score: 1)

.targets等效的TensorDatasettrain_dataset.tensors[1]
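Applied to the question's analytics code, this means replacing train_dataset.targets with train_dataset.tensors[1] and dividing by len(train_dataset). A self-contained sketch, assuming a hypothetical stand-in network and random data, with the prediction loop written inline in place of the tutorial's get_all_preds:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data and network standing in for the question's objects.
features = torch.randn(4050, 1, 150, 6)
targets = torch.randint(0, 10, (4050,))
train_dataset = TensorDataset(features, targets)
network = nn.Sequential(nn.Flatten(), nn.Linear(150 * 6, 10))

# Predictions in dataset order (no shuffling), computed without gradients.
prediction_loader = DataLoader(train_dataset, batch_size=50)
with torch.no_grad():
    train_preds = torch.cat([network(x) for x, _ in prediction_loader], dim=0)

# tensors[0] is `features`, tensors[1] is `targets`: the order passed to TensorDataset.
labels = train_dataset.tensors[1]
preds_correct = train_preds.argmax(dim=1).eq(labels).sum().item()

print('total correct:', preds_correct)
print('accuracy:', preds_correct / len(train_dataset))
```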

TensorDataset's implementation is very simple:

from typing import Tuple

from torch import Tensor
from torch.utils.data import Dataset


class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.
    Each sample will be retrieved by indexing tensors along the first dimension.
    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
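Because tensors is just a tuple, another option is to keep TensorDataset and add a targets attribute yourself. A minimal sketch, assuming the targets are always the second tensor passed in; TensorDatasetWithTargets is a hypothetical name, not part of PyTorch:

```python
import torch
from torch.utils.data import TensorDataset


class TensorDatasetWithTargets(TensorDataset):
    """Hypothetical subclass exposing torchvision-style .data / .targets views."""

    @property
    def data(self):
        # Assumed convention: first tensor passed in holds the inputs.
        return self.tensors[0]

    @property
    def targets(self):
        # Assumed convention: second tensor passed in holds the labels.
        return self.tensors[1]


features = torch.randn(4050, 1, 150, 6)
targets = torch.randint(0, 10, (4050,))
train_dataset = TensorDatasetWithTargets(features, targets)

print(train_dataset.targets.shape)  # torch.Size([4050])
sample, label = train_dataset[0]    # __getitem__ is inherited unchanged
print(sample.shape)                 # torch.Size([1, 150, 6])
```

This lets code written against the torchvision-style .targets attribute run unchanged, at the cost of hard-coding which position holds the labels.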