默认情况下,会话保护程序会保存所有创建的变量,从而导致检查点文件非常大。我想只保存模型参数和某些会话变量,例如优化器状态和全局步骤。在保护程序初始化期间,除了白名单变量之外,最佳做法是什么?
答案 0 :(得分:1)
默认情况下,保护程序从all_variables()
获取变量列表,这是GraphKeys.VARIABLES
集合中的所有变量。您可以使用Variable(..., collections=[]
)从该集合中排除变量。或者你可以把它作为另一个集合,就像代码库中为非检查点limit_epochs
变量
with ops.name_scope(name, "limit_epochs", [tensor]) as name:
zero64 = constant_op.constant(0, dtype=dtypes.int64)
epochs = variables.Variable(
zero64, name="epochs", trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
答案 1 :(得分:1)
经过一些调查(检查点不同的批量大小并打印出all_variables
),我发现我过度担心。实际上,在张量流中,Op的结果不被保存,例如, y
中的y = k * x + b
。因此,与torch-nn
不同,您很少需要担心非参数被保存。
答案 2 :(得分:0)
您可以创建一个包含要保存的所有变量的字典,其中键的名称为字符串。将此字典传递给saver.save()函数。 这就是api建议的内容。