在 R 中过滤火炬数据集

时间:2021-03-31 02:54:00

标签: r pytorch torch torchvision

我正在努力学习《Deep Learning with PyTorch》这本书。我正在使用新的 R 包 torchtorchvision

在第 173 页的 7.2.1 节中,我只是不确定如何过滤此数据集以仅包含标签 1 和 3(对应于书中的 0 和 2)。

这是我的代码,我想知道如何按照书中的代码过滤 transformed_cifar10。含义对其进行过滤,使 transformed_cifar10$y 标签仅包含 1 和 3。然后将 {1,3} 重新映射到 {1,2}。

library(dplyr)
library(torch)
library(torchvision)

data_path <- "./ch7/data" # need to change this?

train_transforms <- function (img) {
  img %>% 
    transform_to_tensor() %>% 
    transform_normalize(mean = c(0.4915, 0.4823, 0.4468),
                        std = c(0.2470, 0.2435, 0.2616))
}

transformed_cifar10 <- cifar10_dataset(data_path, 
                                       train = TRUE, 
                                       download = TRUE, 
                                       transform = train_transforms)

这是书中的python代码:

# In[5]:
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
          for img, label in cifar10
          if label in [0, 2]]

起初我想尝试这样的事情,但显然它不起作用......

tensor_cifar10[tensor_cifar10$y == 1]

0 个答案:

没有答案