# 80/20 hold-out split by position: the first 80% of the indices are used for
# training and the remaining 20% for validation. No shuffling happens here;
# SubsetRandomSampler shuffles within each subset later.
index_list = list(range(len(train_dataset)))
split = int(len(index_list) * 0.8)
print(split)
train_idx = index_list[:split]
valid_idx = index_list[split:]
print(len(train_idx), len(valid_idx))
48000 12000
我得到了 train_idx 和 valid_idx，分别包含 48000 个和 12000 个索引。
然后我把这些索引通过 SubsetRandomSampler 传给 DataLoader：
# Both loaders wrap the SAME full dataset; the SubsetRandomSampler is what
# restricts each loader to its own index subset (and shuffles it each epoch).
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_idx))
valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx))
# `loader.dataset` is always the whole wrapped dataset, so this prints the full
# size for both loaders (60000 60000 in the output below) -- it is NOT the
# subset size.
print(len(train_loader.dataset), len(valid_loader.dataset))
# The number of samples each loader actually yields per epoch is the length of
# its sampler, i.e. len(train_idx) and len(valid_idx).
# (len(loader) itself would give the number of batches, not samples.)
print(len(train_loader.sampler), len(valid_loader.sampler))
60000 60000
但是 len 的结果看起来不对：两个 loader 报告的都是 60000，而不是预期的 48000 和 12000。
for epoch in range(EPOCHS):
for i , (train_idx, valid_idx) in enumerate(splits):
## create iterator objects for train and valid datasets
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler= SubsetRandomSampler(train_idx))
valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx))
submit_loader = DataLoader(dataset = test_dataset,batch_size = batch_size, shuffle = True)
train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
valid_loss, valid_acc = evaluate(model, device, valid_loader, criterion)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f'| Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:05.2f}% | Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:05.2f}% |')