在TensorFlow中保存特定权重

时间:2016-09-12 12:12:46

标签: tensorflow

在我的神经网络中,我创建了一些tf.Variable个对象,如下所示:

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

如何在经过特定次数的迭代后保存变量weightsbiases而不保存其他变量?

1 个答案:

答案 0 :(得分:7)

在TensorFlow中保存变量的标准方法是使用tf.train.Saver对象。默认情况下,它会保存问题中的所有变量(即tf.all_variables()的结果),但您可以通过将var_list可选参数传递给tf.train.Saver构造函数来有选择地保存变量:< / p>

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)

# ...
# You need a TensorFlow Session to save variables.
sess = tf.Session()
# ...

# ...then call the following methods as appropriate:
weights_saver.save(sess)  # Save the current value of the weights.
biases_saver.save(sess)   # Save the current value of the biases.

请注意,如果您将字典传递给tf.train.Saver构造函数(例如问题中的weights和/或biases字典),TensorFlow将使用字典键(例如{ {1}})作为其创建或使用的任何检查点文件中相应变量的名称

默认情况下,或者如果将'wc1_0'个对象列表传递给构造函数,TensorFlow将使用tf.Variable属性。

传递字典使您能够在模型之间共享检查点,这些模型为每个变量提供不同的tf.Variable.name属性。 仅当您要将创建的检查点与其他模型一起使用时,此详细信息才很重要。