我想知道在每次训练时代之后使用tf.train.Saver()保存我的模型时究竟会保存什么。与我以前使用的Keras模型相比,该文件似乎有点大。现在我的RNN每次保存需要900 MB。有没有办法告诉保护者只保存可训练的参数?我还想要一种只保存部分模型的方法。我知道我可以获取我定义的变量并使用numpy格式保存它们但是当我使用RNN类时我不能直接访问它们的权重并且我查看了代码并且没有像get_weights那样的东西我可以看到。
答案 0 :(得分:5)
您可以提供要在Saver
构造函数中保存的变量列表,即saver=tf.train.Saver(var_list=tf.trainable_variables())
答案 1 :(得分:4)
如果variables._all_saveable_objects()
未指定Saver
,则默认会保存所有var_list
。
也就是说,Saver
默认会保存所有全局变量和可保存变量。
def _all_saveable_objects():
"""Returns all variables and `SaveableObject`s that must be checkpointed.
Returns:
A list of `Variable` and `SaveableObject` to be checkpointed
"""
# TODO(andreasst): make this function public once things are settled.
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))