使用torch.utils.data.random_split
时我不会分裂。
我得到train_size
和val_size
的正确数字,但是当我执行random_split
时,train_data
和val_data
都得到full_data
。没有分裂发生。
请帮助我解决这个问题。
class DeviceLoader(Dataset):
def __init__(self, root_dir, train=True, transform=None):
self.file_path = root_dir
self.train = train
self.transform = transform
self.file_names = ['%s/%s'%(root,file) for root,_,files in os.walk(root_dir) for file in files]
self.len = len(self.file_names)
self.labels = {'BP_Raw_Images':0, 'DT_Raw_Images':1, 'GL_Raw_Images':2, 'PO_Raw_Images':3, 'WS_Raw_Images':4}
def __len__(self):
return(len(self.file_names))
def __getitem__(self, idx):
file_name = self.file_names[idx]
device = file_name.split('/')[5]
img = self.pil_loader(file_name)
if(self.transform):
img = self.transform(img)
cat = self.labels[device]
if(self.train):
return(img, cat)
else:
return(img, file_name)
full_data = DeviceLoader(root_dir=’/kaggle/input/devices/dataset/’, transform=transforms, train=True)
train_size = int(0.7*len(full_data))
val_size = len(full_data) - train_size
train_data, val_data = torch.utils.data.random_split(full_data,[train_size,val_size])
预期结果是将full_data
分为train_data
(2000)和val_data
(500)。但是相反,我在火车和火车上都得到了full_data
(2500)。