我在尝试使用pytorch ignite运行自定义验证过程时遇到问题
我的验证不需要消耗ignite的数据加载器或使用引擎, tt只是一个python函数,可以运行并将内容保存到dict和磁盘中,而其他代码则可以管理数据流。
在具有分布式训练的较大数据集上,它调用run_eval_
一次,然后冻结。
即使我仅在评估模式下使用数据加载器,如果我进行分布式培训,是否所有数据加载器都需要使用DistributedSampler
?
def train_loop(args, path):
model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
def update(engine, batch):
model.train()
batch = to_gpu(batch)
#batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
outputs = model(*batch)
loss = loss_fn(*outputs)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss.item()
trainer = Engine(update)
metrics = []
eval_data = EvalData(eval_path)
def _run_eval():
metrics.append(eval_data.score(model))
print('RAN EVAL')
pickle_save(metrics.pkl) # This gets hit once then never again
return
run_eval = lambda _: run_eval()
if args.distributed:
trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: d.train_sampler.set_epoch(engine.state.epoch))
trainer.add_event_handler(Events.EPOCH_COMPLETED, run_eval)
trainer.run(train_dl, max_epochs=5)