Tensorflow tf.train.Saver可以保存可疑的大.ckpt文件?

时间:2016-02-11 06:42:52

标签: neural-network tensorflow conv-neural-network

我正在使用合理大小的网络(1个卷积层,2个完全连接的层)。每次使用tf.train.Saver保存变量时,.ckpt文件的每个磁盘空间都是半GB(准确地说是512 MB)。这是正常的吗?我有一个具有相同架构的Caffe网络,只需要一个7MB .caffemodel文件。 Tensorflow保存如此大的文件大小是否有特殊原因?

非常感谢。

3 个答案:

答案 0 :(得分:5)

很难说你的网络与你所描述的有多大 - 两个完全连接的层之间的连接数量与每个层的大小成比例地放大,所以也许你的网络非常大,具体取决于完全连接的图层。

如果您想在检查点文件中节省空间,可以替换此行:

saver = tf.train.Saver()

以下内容:

saver = tf.train.Saver(tf.trainable_variables())

默认情况下,tf.train.Saver()会保存图表中的所有变量 - 包括优化程序创建的用于累积渐变信息的变量。告诉它只保存 trainable 变量意味着它只会保存网络的权重和偏差,并丢弃累积的优化器状态。您的检查点可能会小很多,需要权衡的是,在您恢复训练后,您可能会遇到前几个训练批次的较慢训练,而优化程序会重新累积梯度信息。根据我的经验,我不需要很长时间才能恢复速度,所以我个人认为,对于较小的检查站来说,这种权衡是值得的。

答案 1 :(得分:1)

也许你可以尝试(在Tensorflow 1.0中):

saver.save(sess, filename, write_meta_graph=False)

不保存元图信息。 看到:  https://www.tensorflow.org/versions/master/api_docs/python/tf/train/Saver  https://www.tensorflow.org/programmers_guide/meta_graph

答案 2 :(得分:0)

通常,您只保存tf.global_variables()(这是tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)的简写,即全局变量的集合)。此集合旨在包含恢复模型状态所必需的变量,例如批量规范化的当前移动平均值,全局步骤,优化程序的状态,当然还有{{1}收集。更多临时性质的变量(例如渐变)收集在tf.GraphKeys.TRAINABLE_VARIABLES中,通常不需要存储它们,它们可能占用大量磁盘空间。