GPT-2从检查站继续训练

时间:2020-02-06 14:51:59

标签: python tensorflow nlp google-colaboratory gpt-2

我正在尝试使用以下用于GPT-2-simple的colab设置从保存的检查点继续进行训练:

https://colab.research.google.com/drive/1SvQne5O_7hSdmPvUXl5UzPeG5A6csvRA#scrollTo=aeXshJM-Cuaf

但是我只是无法正常工作。从我的googledrive加载保存的检查点工作正常,我可以用它来生成文本,但是我无法从该检查点继续训练。在1中,我输入isNaN(...)gpt2.finetune (),并且试图同时使用相同的run_name和不同的run_name,而不使用restore.from='latest"。我也尝试过按照建议的那样重新启动运行时,但这没有帮助,我不断收到以下错误:

overwrite=True

我假设在继续训练之前需要先运行overwrite=True,但是无论何时我第一次运行,"ValueError: Variable model/wpe already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" 都会引发此错误

2 个答案:

答案 0 :(得分:1)

在微调之前,您不需要(也不能)运行load_gpt2()。您只需将run_name赋予finetune()。我同意这令人困惑。我也遇到了同样的麻烦。

sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
    file_name,
    model_name=model_name,
    checkpoint_dir=checkpoint_dir,
    run_name=run_name,
    steps=25,
)

这将自动从您的checkpoint/run-name文件夹中获取最新的检查点,加载其权重,并在中断的地方继续进行训练。您可以通过检查纪元编号来确认这一点-它不会从0再次开始。例如,如果您以前训练过25个纪元,则将从26开始:

Training...

[26 | 7.48] loss=0.49 avg=0.49

还要注意,要多次运行微调(或加载另一个模型),通常必须重新启动python运行时。您可以改为在每个finetine命令之前运行此命令:

tf.reset_default_graph()

答案 1 :(得分:1)

我尝试了以下方法并且工作正常:

tf.reset_default_graph()
sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
          steps=n,
          dataset=file_name,
          model_name='model', 
          print_every=z,
          run_name= 'run_name',
          restore_from='latest',
          sample_every=x,
          save_every=y
          )

您必须指定与您要继续训练的模型相同的“run_name”和 hp restore_from = 'latest'