如何在tensorflow中只保存必要的参数?

时间:2016-08-17 20:24:23

标签: tensorflow

默认情况下,会话保护程序会保存所有创建的变量,从而导致检查点文件非常大。我想只保存模型参数和某些会话变量,例如优化器状态和全局步骤。在保护程序初始化期间,除了白名单变量之外,最佳做法是什么?

3 个答案:

答案 0 :(得分:1)

默认情况下,保护程序从all_variables()获取变量列表,这是GraphKeys.VARIABLES集合中的所有变量。您可以使用Variable(..., collections=[])从该集合中排除变量。或者你可以把它作为另一个集合,就像代码库中为非检查点limit_epochs变量

所做的那样
 with ops.name_scope(name, "limit_epochs", [tensor]) as name:
    zero64 = constant_op.constant(0, dtype=dtypes.int64)
    epochs = variables.Variable(
        zero64, name="epochs", trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])

答案 1 :(得分:1)

经过一些调查(检查点不同的批量大小并打印出all_variables),我发现我过度担心。实际上,在张量流中,Op的结果不被保存,例如, y中的y = k * x + b。因此,与torch-nn不同,您很少需要担心非参数被保存。

答案 2 :(得分:0)

您可以创建一个包含要保存的所有变量的字典,其中键的名称为字符串。将此字典传递给saver.save()函数。 这就是api建议的内容。