torch.utils.data.random_split()未拆分数据

时间:2019-09-04 13:37:17

标签: deep-learning pytorch

使用torch.utils.data.random_split时我不会分裂。

我得到train_sizeval_size的正确数字,但是当我执行random_split时,train_dataval_data都得到full_data。没有分裂发生。

请帮助我解决这个问题。

class DeviceLoader(Dataset):

def __init__(self, root_dir, train=True, transform=None):
    self.file_path = root_dir
    self.train = train
    self.transform = transform
    self.file_names = ['%s/%s'%(root,file) for root,_,files in os.walk(root_dir) for file in files]
    self.len = len(self.file_names)
    self.labels = {'BP_Raw_Images':0, 'DT_Raw_Images':1, 'GL_Raw_Images':2, 'PO_Raw_Images':3, 'WS_Raw_Images':4}

def __len__(self):
    return(len(self.file_names))

def __getitem__(self, idx):
    file_name = self.file_names[idx]
    device = file_name.split('/')[5]
    img = self.pil_loader(file_name)
    if(self.transform):
        img = self.transform(img)
    cat = self.labels[device]            
    if(self.train):
        return(img, cat)
    else:
        return(img, file_name)
full_data = DeviceLoader(root_dir=’/kaggle/input/devices/dataset/’, transform=transforms, train=True)
train_size = int(0.7*len(full_data))
val_size = len(full_data) - train_size
train_data, val_data = torch.utils.data.random_split(full_data,[train_size,val_size])

预期结果是将full_data分为train_data(2000)和val_data(500)。但是相反,我在火车和火车上都得到了full_data(2500)。

1 个答案:

答案 0 :(得分:1)

从下面的图像中您可以看到,它实际上构成了数据的子集,但原始数据集仍然存在。这可能会造成混淆。我对mnist数据集进行了以下操作

    train, validate, test = data.random_split(training_set, [50000, 10000, 10000])
    print(len(train))
    print(len(validate))
    print(len(test))

输出:

50000
10000
10000

enter image description here