在我的神经网络中,我创建了一些tf.Variable
个对象,如下所示:
weights = {
'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}
如何在经过特定次数的迭代后保存变量weights
和biases
而不保存其他变量?
答案 0 :(得分:7)
在TensorFlow中保存变量的标准方法是使用tf.train.Saver
对象。默认情况下,它会保存问题中的所有变量(即tf.all_variables()
的结果),但您可以通过将var_list
可选参数传递给tf.train.Saver
构造函数来有选择地保存变量:< / p>
weights = {
'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}
# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)
# ...
# You need a TensorFlow Session to save variables.
sess = tf.Session()
# ...
# ...then call the following methods as appropriate:
weights_saver.save(sess) # Save the current value of the weights.
biases_saver.save(sess) # Save the current value of the biases.
请注意,如果您将字典传递给tf.train.Saver
构造函数(例如问题中的weights
和/或biases
字典),TensorFlow将使用字典键(例如{ {1}})作为其创建或使用的任何检查点文件中相应变量的名称。
默认情况下,或者如果将'wc1_0'
个对象列表传递给构造函数,TensorFlow将使用tf.Variable
属性。
传递字典使您能够在模型之间共享检查点,这些模型为每个变量提供不同的tf.Variable.name
属性。
仅当您要将创建的检查点与其他模型一起使用时,此详细信息才很重要。