I want to train my transformer model with MAML-style gradient updates. The higher library (facebookresearch/higher) looks like an effective way to do this, but I am open to other approaches as well.
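For context, the pattern I am trying to follow is the usual higher inner-loop recipe, roughly like this toy sketch (the linear model, cross-entropy loss, and random data are placeholders, not my actual setup):

import torch
import torch.nn.functional as F
import higher

model = torch.nn.Linear(10, 2)
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
inner_opt = torch.optim.SGD(model.parameters(), lr=1e-2)

support_x, support_y = torch.randn(5, 10), torch.randint(0, 2, (5,))
query_x, query_y = torch.randn(5, 10), torch.randint(0, 2, (5,))

meta_opt.zero_grad()
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
    # Inner loop: adapt the functional copy of the model on the support set.
    inner_loss = F.cross_entropy(fmodel(support_x), support_y)
    diffopt.step(inner_loss)
    # Outer loss: evaluate the adapted copy on the query set and
    # backpropagate through the inner step to the original parameters.
    outer_loss = F.cross_entropy(fmodel(query_x), query_y)
    outer_loss.backward()
meta_opt.step()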
My MetaDataset wraps any GlueDataset so that calling meta_dataset[0] returns a list containing examples from all classes, i.e. K-shot examples for each of the num_of_classes (N) ways.
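Roughly, the wrapper behaves like the sketch below (the class body, the k parameter, and the grouping by a .label attribute are only illustrative; my real implementation sits on top of GlueDataset and returns tokenized features):

from collections import defaultdict
from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Illustrative sketch: group the wrapped dataset by label and return
    K examples per class (N-way K-shot) for every index."""

    def __init__(self, dataset, k=4):
        self.dataset = dataset
        self.k = k
        self.by_label = defaultdict(list)
        for example in dataset:
            self.by_label[int(example.label)].append(example)

    def __len__(self):
        return min(len(v) for v in self.by_label.values()) // self.k

    def __getitem__(self, idx):
        # One meta example: a list with K items from each of the N classes.
        start = idx * self.k
        return [examples[start:start + self.k]
                for examples in self.by_label.values()]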
I have written the following, which extends the HF Trainer for MAML.
def train(self):
    self.create_optimizer_and_scheduler(
        int(
            len(self.train_dataloader)
            // self.args.gradient_accumulation_steps
            * self.args.num_train_epochs
        )
    )

    logger.info("***** Running training *****")
    self.global_step = 0
    self.epoch = 0
    eval_step = [2 ** i for i in range(1, 20)]
    inner_optimizer = torch.optim.SGD(
        self.model.parameters(), lr=self.args.step_size
    )
    self.model.train()
    tqdm_iterator = tqdm(self.train_dataloader, desc="Batch Index")

    # n_inner_iter = 5
    self.optimizer.zero_grad()
    query_dataloader = iter(self.train_dataloader)

    for batch_idx, meta_batch in enumerate(tqdm_iterator):
        target_batch = next(query_dataloader)
        outer_loss = 0.0

        # Loop through all classes: each element of the meta batch is the
        # per-class support batch, paired with a query (target) batch.
        for inputs, target_inputs in zip(meta_batch, target_batch):
            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)
                target_inputs[k] = v.to(self.args.device)

            # Inner loop: adapt a functional copy of the model on the support
            # batch, then accumulate the query loss for the meta update.
            with higher.innerloop_ctx(
                self.model, inner_optimizer, copy_initial_weights=False
            ) as (fmodel, diffopt):
                inner_loss = fmodel(**inputs)[0]
                diffopt.step(inner_loss)
                outer_loss += fmodel(**target_inputs)[0]

        self.global_step += 1

        # Outer (meta) optimizer step and backward pass on the accumulated query loss
        self.optimizer.step()
        outer_loss.backward()

        if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.args.max_grad_norm
            )

        # Run evaluation on the task list at exponentially spaced steps
        if self.global_step in eval_step:
            output = self.prediction_loop(self.eval_dataloader, description="Evaluation")
            self.log(output.metrics)
            output_dir = os.path.join(
                self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}",
            )
            self.save_model(output_dir)
The code above does not seem to be working correctly, since the accuracy is not improving. Any direction/hints would be much appreciated.