我用Tensorflow Eager编写了一个自定义模型(类似于此example)。我想存储/恢复我的可训练变量 - 类似于以下非渴望逻辑。我怎样才能在Eager中做到这一点?:
def store(self, sess_var, model_path):
if model_path is not None:
saver = tf.train.Saver(var_list=tf.trainable_variables())
save_path = saver.save(sess_var, model_path)
print("Model saved in path: %s" % save_path)
else:
print("Model path is None - Nothing to store")
def restore(self, sess_var, model_path):
if model_path is not None:
if os.path.exists("{}.index".format(model_path)):
saver = tf.train.Saver(var_list=tf.trainable_variables())
saver.restore(sess_var, model_path)
print("Model at %s restored" % model_path)
else:
print("Model path does not exist, skipping...")
else:
print("Model path is None - Nothing to restore")
答案 0 :(得分:1)
TensorFlow中的急切执行鼓励在对象中封装模型状态,例如在tf.keras.Model
个对象中。然后可以使用tf.contrib.eager.Checkpoint
请注意,tf.contrib.eager.Checkpoint
类与eager和graph执行兼容。
您会在张量流存储库中的示例中看到这一点,例如this和this
希望有所帮助。