我正在努力学习《Deep Learning with PyTorch》这本书。我正在使用新的 R 包 torch
和 torchvision
。
在第 173 页的 7.2.1 节中,我只是不确定如何过滤此数据集以仅包含标签 1 和 3(对应于书中的 0 和 2)。
这是我的代码,我想知道如何按照书中的代码过滤 transformed_cifar10
。含义对其进行过滤,使 transformed_cifar10$y
标签仅包含 1 和 3。然后将 {1,3} 重新映射到 {1,2}。
library(dplyr)
library(torch)
library(torchvision)
data_path <- "./ch7/data" # need to change this?
train_transforms <- function (img) {
img %>%
transform_to_tensor() %>%
transform_normalize(mean = c(0.4915, 0.4823, 0.4468),
std = c(0.2470, 0.2435, 0.2616))
}
transformed_cifar10 <- cifar10_dataset(data_path,
train = TRUE,
download = TRUE,
transform = train_transforms)
这是书中的python代码:
# In[5]:
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
for img, label in cifar10
if label in [0, 2]]
起初我想尝试这样的事情,但显然它不起作用......
tensor_cifar10[tensor_cifar10$y == 1]