TensorFlow保存问题

时间:2016-07-18 17:42:40

标签: python tensorflow

我正在使用TensorFlow训练神经网络,模型的训练正在使用批量梯度下降的自定义实现。我有一个记录验证错误的记录功能,它下降到约2.6%。我使用tf.train.Saver每隔10个纪元保存模型。

然而,当我使用具有相同脚本的tf.train.Saver再次将变量加载到内存中时,模型表现不佳 - 具有随机初始化权重时的性能。我检查了检查点中的宪法过滤器,但它们似乎并不是随机的。

我没有包含我的所有代码,因为它大约有400行,但我在这里列出了重要的部分,并总结了其他功能。

class ModelTrainer:

    def __init__(self, ...hyperparameters...):

        # Intitialize datasets and hyperparameters

        for each gpu
            # Create loss function and gradient assigned to this gpu using tf.device("/gpu:n")

        with tf.device("/cpu:0")
            # Average and clip gradients from the gpu's

            # Create this batch gradient descent operation for each trainable variable
            variable.assign_sub(learning_rate * averaged_and_clipped_gradient).op


    def train(self, ...hyperparameters...)

        saver = train.Saver(tf.all_variables(), max_to_keep = 30)
        init = tf.initialize_all_variables()
        sess = tf.Session()

        if starting_point is not None:  # Used to evaluate existing models
            saver.restore(sess, starting_point)
        else:
            sess.run(init)

        for i in range(number_of_batches)

            # ... Get training batch ...

            gradients = sess.run(calculate_gradients, feeds = training_batch)

            # Average "gradients" variable across multiple batches
            # Must be done because of GPU memory limitations

            if i % meta_batch_size == 0:
                sess.run(apply_gradients_operators,
                         feeds = gradients_that_have_been_averaged_across_multiple_batches)

            # Log validation error

            if i % save_after_n_batches == 0:
                saver.save(sess, "some-filename", global_step=self.iter_num)

正如预期的那样,运行这两个函数会创建一组名为“some-filename-40001”的检查点文件,或者保存该文件时训练所处的其他迭代次数。不幸的是,当我使用start_point参数加载这些检查点时,它们与随机初始化相同。

最初我认为这与我训练模型的方式有关,因为我没有找到其他人遇到此问题,但验证错误的行为与预期一致。

编辑:更奇怪的结果。经过更多实验,我发现当我使用代码加载保存的模型时:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph("saved-checkpoint-40.meta")
    saver.restore(sess, "saved-checkpoint-40")

    # ... Use model in some way ...

我变得不同,但结果仍然不正确。

0 个答案:

没有答案