Tensorflow,tf.train.Saver保存了什么?

时间:2016-06-18 19:44:48

标签: tensorflow

我想知道在每次训练时代之后使用tf.train.Saver()保存我的模型时究竟会保存什么。与我以前使用的Keras模型相比,该文件似乎有点大。现在我的RNN每次保存需要900 MB。有没有办法告诉保护者只保存可训练的参数?我还想要一种只保存部分模型的方法。我知道我可以获取我定义的变量并使用numpy格式保存它们但是当我使用RNN类时我不能直接访问它们的权重并且我查看了代码并且没有像get_weights那样的东西我可以看到。

2 个答案:

答案 0 :(得分:5)

您可以提供要在Saver构造函数中保存的变量列表,即saver=tf.train.Saver(var_list=tf.trainable_variables())

答案 1 :(得分:4)

如果variables._all_saveable_objects()未指定Saver,则默认会保存所有var_list

也就是说,Saver默认会保存所有全局变量和可保存变量。

def _all_saveable_objects():
  """Returns all variables and `SaveableObject`s that must be checkpointed.

  Returns:
    A list of `Variable` and `SaveableObject` to be checkpointed
  """
  # TODO(andreasst): make this function public once things are settled.
  return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
          ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))