如何为语言翻译重新训练序列到序列神经网络模型?

时间:2019-02-20 06:45:01

标签: python tensorflow recurrent-neural-network language-translation seq2seq

我已经训练了seq2seq张量流模型来将句子从英语翻译成西班牙语。我训练了615700个步骤的模型,并成功保存了模型检查点。我的英语和西班牙语句子的训练数据量都是20万。我想从615700个步骤中重新训练10K个新数据句子的模型。我正在为此使用序列对Tensoflow模型进行序列化。如何从上一个检查点开始重新训练模型? Here是我用于翻译的链接。

我的火车文件夹中有3种文件类型:

RunAsync

我新的训练数据集文件分别是.index .meta .data and checkpoint file. europarl_train.es-en.en,分别用于英语和西班牙语句子。

我编写了代码以加载模型.meta文件和权重

europarl_train.es-en.es

如何开始保留此数据集?

1 个答案:

答案 0 :(得分:0)

保存

根据TensorFlow version 2 doc,您可以使用tf.train.Checkpointtf.train.CheckpointManager类来保存模型。 考虑以下示例:

checkpoint_dir = './training_checkpoints'       # custom directory
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(model=model)   # your model variable name
manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=checkpoint_dir, max_to_keep=3)           # max_to_keep means how much of last checkpoints number you like to keep

现在,如果您想保存模型,请输入:manager.save()

加载

再次定义checkpoint和checkpointManager并运行以下代码:

if manager.latest_checkpoint:
    checkpoint.restore((manager.latest_checkpoint)).assert_consumed()
    print("Restored from {}".format(manager.latest_checkpoint))

如果遇到类似(AssertionError:检查点(根)中未解决的对象)的错误,请用assert_consumed替换expect_partial。 (此处为差异:link

模型已从检查点加载。 现在,您可以加载数据并修复形状并继续训练模型。