在TF2中使用自定义训练循环时,如何保存所有变量(不仅是净变量)以便能够继续训练?

时间:2019-12-19 09:13:42

标签: python tensorflow keras tensorflow2.0 tf.keras

我正在使用自定义训练循环在TF2中训练Model。我希望能够在给定的时刻保存优化状态,以便以后能够重新启动它。要保存的变量是模型参数,还包括优化变量,以及这里和那里的其他两个变量。

在TF1中,这实际上甚至不是问题,因为tf.train.Saver将默认保存所有变量。

现在,如何在TF2中做到这一点?

根据指南,在TF2中,保存是通过Keras公开的功能(使用特定的回调或Model方法)完成的。两种方法可以保存更多的信息,而不仅仅是网络参数,但是要实现此目的,需要使用tf.Model.compile编译模型,以便将所有东西捆绑在一起。但是,使用自定义训练循环,就不会调用compile

那么当一个人不使用compile / fit的正确路径时,如何保存我所有的变量以恢复训练呢?

1 个答案:

答案 0 :(得分:1)

使用tf.train.Checkpoint,然后将所有要保存的变量放入此函数。

tf.train.Checkpoint(model=model, optimizer=optimizer, [xx=xx])

更多详细信息,请参见此tf.train.Checkpoint