如何让多个 DataLoader 以相同的顺序打乱数据?

时间:2020-11-02 03:21:42

标签: python pytorch shuffle dataloader

# ImageFolder roots for the two splits: <args.data>/train and <args.data>/val.
traindir, valdir = [os.path.join(args.data, split) for split in ("train", "val")]

# Per-channel mean/std normalization shared by both splits
# (presumably the usual ImageNet statistics -- confirm against the model).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Build both ImageFolder datasets from (root, transform steps) pairs.
# Training: random resized crop + random horizontal flip (augmentation).
# Validation: deterministic resize to 256 then center crop to 224.
# Both pipelines finish with ToTensor() and the shared ``normalize``.
train_dataset, val_dataset = (
    datasets.ImageFolder(root=root, transform=transforms.Compose(steps))
    for root, steps in (
        (
            traindir,
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ],
        ),
        (
            valdir,
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ],
        ),
    )
)


def worker_init_fn(worker_id):
    """Reseed NumPy's global RNG differently in each DataLoader worker.

    Worker processes may start with identical copies of the parent's global
    NumPy RNG state, in which case NumPy-based randomness would repeat across
    workers. This takes the first word of the current MT19937 state, adds
    ``worker_id``, and reseeds with the result.

    The sum is reduced modulo 2**32 because ``np.random.seed`` only accepts
    values in ``[0, 2**32 - 1]``. Without the reduction, a state word near the
    top of the uint32 range plus a non-zero ``worker_id`` either raises
    ``ValueError`` (NumPy 1.x promotes to int64) or silently overflows with a
    warning (NumPy 2.x).

    Args:
        worker_id: Index of the worker process, as passed by DataLoader.
    """
    # int() detaches from the uint32 scalar so the addition cannot wrap.
    base = int(np.random.get_state()[1][0])
    np.random.seed((base + worker_id) % 2**32)

def _init_fn(worker_id):
    np.random.seed(42)
# The three training loaders must visit the samples in the same order so that
# ``zip(train_loader, train_loader2, train_loader3)`` hands identical labels to
# the three networks. With ``shuffle=True`` and no generator, each DataLoader
# draws its own independent permutation every epoch, which is why the labels
# disagreed.
#
# Fix: give every loader its OWN torch.Generator, all seeded with the SAME
# value. Each loader draws its shuffle permutation from its generator. Because
# the three generators start identical and are consumed in the same way each
# epoch, they keep producing identical orders.
#
# Sharing one generator object instead would NOT work: the loaders would
# consume it one after another and get different permutations.
#
# The shared seed is drawn once from the global torch RNG. Calling
# torch.manual_seed(...) beforehand therefore still makes runs reproducible,
# and unseeded runs still get a fresh order per run.
#
# NOTE(review): requires PyTorch >= 1.6 (DataLoader ``generator`` argument).
# The loaders stay aligned only while all three are iterated the same number
# of times, as the zip loop below does.
_shuffle_seed = int(torch.empty((), dtype=torch.int64).random_().item())

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True,
    generator=torch.Generator().manual_seed(_shuffle_seed))
train_loader2 = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True,
    generator=torch.Generator().manual_seed(_shuffle_seed))
train_loader3 = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True,
    generator=torch.Generator().manual_seed(_shuffle_seed))

# Three independent validation loaders over the same dataset. With
# shuffle=False each one walks val_dataset in index order, so their batches
# already line up without any shared seeding.
val_loader, val_loader2, val_loader3 = [
    torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
    for _ in range(3)
]


# Step the three training loaders in lockstep and print each loader's label
# tensor, so it is easy to see whether the three batches carry the same labels.
for i, dd in enumerate(zip(train_loader, train_loader2, train_loader3)):
    # Each element of dd is one loader's (inputs, targets) batch.
    (input1, target1), (input2, target2), (input3, target3) = dd

    print("1", target1)
    print("2", target2)
    print("3", target3)

需求是:将同一份数据(及其标签)复制三份,以相同的随机顺序并行打乱,然后分别送入三个不同的网络进行联合训练。但是,如果将 shuffle 设置为 True,三个 DataLoader 副本在每个批次产生的标签各不相同,因此无法进行联合训练。

0 个答案:

没有答案