Tensorflow Eager:存储和恢复可训练变量

时间:2018-04-28 03:12:49

标签: tensorflow

我用Tensorflow Eager编写了一个自定义模型(类似于此example)。我想存储/恢复我的可训练变量 - 类似于以下非渴望逻辑。我怎样才能在Eager中做到这一点?:

def store(self, sess_var, model_path):
    if model_path is not None:
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        save_path = saver.save(sess_var, model_path)
        print("Model saved in path: %s" % save_path)
    else:
        print("Model path is None - Nothing to store")

def restore(self, sess_var, model_path):
    if model_path is not None:
        if os.path.exists("{}.index".format(model_path)):
            saver = tf.train.Saver(var_list=tf.trainable_variables())
            saver.restore(sess_var, model_path)
            print("Model at %s restored" % model_path)
        else:
            print("Model path does not exist, skipping...")
    else:
        print("Model path is None - Nothing to restore")

1 个答案:

答案 0 :(得分:1)

TensorFlow中的急切执行鼓励在对象中封装模型状态,例如在tf.keras.Model个对象中。然后可以使用tf.contrib.eager.Checkpoint

保存和恢复这些对象的状态("检查点"变量值)

请注意,tf.contrib.eager.Checkpoint类与eager和graph执行兼容。

您会在张量流存储库中的示例中看到这一点,例如thisthis

希望有所帮助。