Meta-learning with transformers

Time: 2020-08-19 14:05:07

Tags: pytorch huggingface-transformers

I want to train my transformer model with MAML-style gradient updates, and the higher library looks like an effective way to do this. I am also open to other approaches.

My MetaDataset wraps any GlueDataset so that calling meta_dataset[0] returns a list containing examples from every class, which gives num_of_classes (N)-way, K-shot episodes.
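Roughly, the wrapper behaves like the simplified sketch below (this is not my actual implementation; the class internals and the k_shot parameter are only illustrative of the N-way/K-shot grouping):

import collections
from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Simplified sketch: group a GlueDataset's examples by label so that one item
    is a list holding a K-shot slice for each of the N classes."""

    def __init__(self, glue_dataset, k_shot=8):
        self.k_shot = k_shot
        self.by_label = collections.defaultdict(list)
        for feature in glue_dataset:
            # GlueDataset yields InputFeatures objects with a `label` field
            self.by_label[int(feature.label)].append(feature)

    def __len__(self):
        # Number of complete K-shot episodes supported by the rarest class
        return min(len(v) for v in self.by_label.values()) // self.k_shot

    def __getitem__(self, idx):
        start = idx * self.k_shot
        # One K-shot group per class -> a num_of_classes (N)-way episode
        return [feats[start:start + self.k_shot] for feats in self.by_label.values()]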

I have written the following, extending the HF Trainer for MAML.

def train(self):

        self.create_optimizer_and_scheduler(
            int(
                len(self.train_dataloader)
                // self.args.gradient_accumulation_steps
                * self.args.num_train_epochs
            )
        )

        logger.info("***** Running training *****")

        self.global_step = 0
        self.epoch = 0

        eval_step = [2 ** i for i in range(1, 20)]
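        # Plain SGD acts as the (differentiable) inner-loop optimizer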
        inner_optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.args.step_size
        )
        self.model.train()

        tqdm_iterator = tqdm(self.train_dataloader, desc="Batch Index")

        #  n_inner_iter = 5
        self.optimizer.zero_grad()
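        # A second iterator over the same dataloader supplies the query (target) batches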
        query_dataloader = iter(self.train_dataloader)

        for batch_idx, meta_batch in enumerate(tqdm_iterator):
            target_batch = next(query_dataloader)
            outer_loss = 0.0
            # Loop through all classes
            for inputs, target_inputs in zip(meta_batch, target_batch):

                # Move both the support and the query tensors to the device
                for k, v in inputs.items():
                    inputs[k] = v.to(self.args.device)
                    target_inputs[k] = target_inputs[k].to(self.args.device)

                with higher.innerloop_ctx(
                    self.model, inner_optimizer, copy_initial_weights=False
                ) as (fmodel, diffopt):

                    # Inner-loop (support) step on a differentiable copy of the model
                    inner_loss = fmodel(**inputs)[0]
                    diffopt.step(inner_loss)
                    # Query loss evaluated with the adapted weights
                    outer_loss += fmodel(**target_inputs)[0]

            # Accumulate meta-gradients from the query losses
            outer_loss.backward()
            self.global_step += 1

            if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.args.max_grad_norm
                )
                # Outer (meta) update on the original parameters, then reset gradients
                self.optimizer.step()
                self.optimizer.zero_grad()

            # Run evaluation on task list
            if self.global_step in eval_step:
                output = self.prediction_loop(self.eval_dataloader, description="Evaluation")
                self.log(output.metrics)

                output_dir = os.path.join(
                    self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}",
                )
                self.save_model(output_dir)

The code above does not seem to be working, since accuracy is not improving. Any direction/hints would be appreciated.

0 Answers:

No answers yet