I'm looking for a way to save checkpoints of a tf.Estimator and later restore an estimator from one of them. Pseudocode:

import tensorflow as tf


class Model:
    def create_model(self):
        self.estimator = tf.estimator.Estimator(...)

    def train(self):
        self.estimator.train(..., steps=10)

    def save(self, checkpoint_path):
        # Save a checkpoint; each checkpoint may be saved in a different directory
        pass

    def restore(self, checkpoint_path):
        # Should return an estimator built from the checkpoint,
        # i.e. something like self.estimator(checkpoint_path)
        pass
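One possible approach, sketched under assumptions: tf.estimator.Estimator already writes checkpoints named model.ckpt-<global_step> into its model_dir while train() runs, so save() could copy the newest of those into another directory. The helpers latest_checkpoint_prefix and save_checkpoint below are hypothetical, use only the Python standard library, and assume the default model.ckpt prefix and the text format of TensorFlow's checkpoint state file. They are not part of the TensorFlow API.

```python
import glob
import os
import re
import shutil


def latest_checkpoint_prefix(model_dir):
    """Hypothetical helper: path prefix of the newest model.ckpt-<step> in model_dir, or None."""
    steps = []
    for index_file in glob.glob(os.path.join(model_dir, "model.ckpt-*.index")):
        match = re.search(r"model\.ckpt-(\d+)\.index$", index_file)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(model_dir, "model.ckpt-%d" % max(steps))


def save_checkpoint(model_dir, checkpoint_path):
    """Hypothetical helper: copy the newest checkpoint of model_dir into checkpoint_path."""
    prefix = latest_checkpoint_prefix(model_dir)
    if prefix is None:
        raise FileNotFoundError("no model.ckpt-* checkpoint found in %s" % model_dir)
    os.makedirs(checkpoint_path, exist_ok=True)
    # Assumed layout: <prefix>.index, <prefix>.data-*-of-*, and usually <prefix>.meta.
    # Event files and graph.pbtxt are not copied.
    for src in glob.glob(prefix + ".*"):
        shutil.copy2(src, checkpoint_path)
    # Write a 'checkpoint' state file (relative path) so this copy is seen as the latest one.
    name = os.path.basename(prefix)
    with open(os.path.join(checkpoint_path, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "%s"\n' % name)
        f.write('all_model_checkpoint_paths: "%s"\n' % name)
    return os.path.join(checkpoint_path, name)
```

Inside the class, save() could then call save_checkpoint(self.estimator.model_dir, checkpoint_path). For restore(), constructing a new tf.estimator.Estimator with the same model_fn and model_dir=checkpoint_path should make it load the copied checkpoint on the next train/evaluate/predict call. When TensorFlow is available, tf.train.latest_checkpoint(model_dir) can replace the hand-written latest_checkpoint_prefix.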