import os

import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

traindir = os.path.join(args.data, "train")
valdir = os.path.join(args.data, "val")

# ImageNet channel statistics for input normalization
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(root=traindir, transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]))
val_dataset = datasets.ImageFolder(root=valdir, transform=transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]))
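# Note (added for context, not in the original script): ImageFolder expects
# one subdirectory per class under each root, e.g.
#   train/<class_name>/xxx.png
#   val/<class_name>/yyy.png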
def worker_init_fn(worker_id):
    # Derive a distinct NumPy seed for each worker from the current RNG
    # state, so workers draw different random augmentations.
    base_seed = np.random.get_state()[1][0]
    np.random.seed(base_seed + worker_id)

def _init_fn(worker_id):
    # Alternative: give every worker the same fixed seed (all workers will
    # then produce identical augmentation randomness).
    np.random.seed(42)
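# Note (added sketch, not in the original script): neither initializer above
# is actually used by the loaders below. For one to take effect it must be
# passed to the DataLoader, e.g.
#   torch.utils.data.DataLoader(..., worker_init_fn=worker_init_fn)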
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)
train_loader2 = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)
train_loader3 = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=args.batch_size, shuffle=False,
    # batch_size=len(val_dataset), shuffle=False,  # alternative: whole val set in one batch
    num_workers=args.workers, pin_memory=True)
val_loader2 = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=args.batch_size, shuffle=False,
    # batch_size=len(val_dataset), shuffle=False,  # alternative: whole val set in one batch
    num_workers=args.workers, pin_memory=True)
val_loader3 = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=args.batch_size, shuffle=False,
    # batch_size=len(val_dataset), shuffle=False,  # alternative: whole val set in one batch
    num_workers=args.workers, pin_memory=True)
for i, dd in enumerate(zip(train_loader, train_loader2, train_loader3)):
    input1, target1 = dd[0]
    input2, target2 = dd[1]
    input3, target3 = dd[2]
    print("1", target1)
    print("2", target2)
    print("3", target3)
The requirement is that three copies of the same data should be shuffled in parallel, with identical labels per batch, and fed into three different networks for joint training. However, with shuffle set to True, the three DataLoader copies produce a different label order for each batch, so the networks cannot be trained together.
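One way to keep the three shuffles aligned, assuming a PyTorch version whose DataLoader accepts the generator argument, is to give each loader its own torch.Generator seeded with the same value, so every loader draws the identical shuffle permutation. This is a minimal sketch, not code from the original script; SEED and make_loader are hypothetical names introduced here:

SEED = 0  # hypothetical shared seed; any fixed value works

def make_loader():
    # Each loader gets a fresh generator seeded identically, so the
    # shuffled sample order matches across all three copies.
    g = torch.Generator()
    g.manual_seed(SEED)
    return torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        generator=g)

train_loader = make_loader()
train_loader2 = make_loader()
train_loader3 = make_loader()

As long as each loader is iterated the same number of times, the generators stay in step across epochs. Alternatively, a single shuffled loader can feed the same (input, target) batch to all three networks, which guarantees alignment and avoids reading the data three times.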