我正在使用自定义训练循环在TF2中训练Model
。我希望能够在给定的时刻保存优化状态,以便以后能够重新启动它。要保存的变量是模型参数,还包括优化变量,以及这里和那里的其他两个变量。
在TF1中,这实际上甚至不是问题,因为tf.train.Saver
将默认保存所有变量。
现在,如何在TF2中做到这一点?
根据指南,在TF2中,保存是通过Keras公开的功能(使用特定的回调或Model
方法)完成的。两种方法可以保存更多的信息,而不仅仅是网络参数,但是要实现此目的,需要使用tf.Model.compile
编译模型,以便将所有东西捆绑在一起。但是,使用自定义训练循环,就不会调用compile
。
那么当一个人不使用compile
/ fit
的正确路径时,如何保存我所有的变量以恢复训练呢?
答案 0 :(得分:1)
使用tf.train.Checkpoint
,然后将所有要保存的变量放入此函数。
tf.train.Checkpoint(model=model, optimizer=optimizer, [xx=xx])
更多详细信息,请参见此tf.train.Checkpoint