如何按需保存tf.estimator的检查点

时间:2019-01-21 18:05:31

标签: tensorflow

我正在寻找一种方法来保存tf.Estimator的检查点。 伪代码:

class model:
    def create_model():
        self.estimator=tf.Estimator(...)
    def train():
        self.estimator.train(..., steps=10)
    def save(checkpoint_path):
        #save chackpoint, every checkpoint may be saved in       diffrent directory
        pass
    def restore(checkpoint_path):
        # return estimator build from checkpoint
        return self.estimator(checkpoint_path)

0 个答案:

没有答案