在Tensorflow中保存AdaGrad算法的状态

时间:2016-11-11 11:36:47

标签: tensorflow

我正在尝试训练word2vec模型,并希望将嵌入用于其他应用程序。由于以后可能会有额外的数据,而且我的计算机在训练时速度很慢,我希望我的脚本能够停止并在以后恢复训练。

为此,我创建了一个保护程序:

saver = tf.train.Saver({"embeddings": embeddings,"embeddings_softmax_weights":softmax_weights,"embeddings_softmax_biases":softmax_biases})

我保存了嵌入,softmax权重和偏差,以便我可以在以后恢复训练。 (我认为这是正确的方法,但如果我错了请纠正我)。

不幸的是,当使用此脚本恢复训练时,平均损失似乎再次上升。

我的想法是,这可以归因于我正在使用的AdaGradOptimizer。最初,外部产品矩阵可能会被设置为全零,在我的训练之后它将被填充(导致较低的学习率)。

有没有办法保存优化器状态以便以后继续学习?

1 个答案:

答案 0 :(得分:6)

当您尝试直接序列化优化器对象(例如通过tf.add_to_collection("optimizers", optimizer)并随后调用tf.train.Saver().save())时,TensorFlow似乎会抱怨,您可以保存并恢复从优化器派生的训练更新操作:

# init
if not load_model:
    optimizer = tf.train.AdamOptimizer(1e-4)
    train_step = optimizer.minimize(loss)
    tf.add_to_collection("train_step", train_step)
else:
    saver = tf.train.import_meta_graph(modelfile+ '.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    train_step = tf.get_collection("train_step")[0]

# training loop
while training:
    if iteration % save_interval == 0:
        saver = tf.train.Saver()
        save_path = saver.save(sess, filepath)

我不知道获取或设置特定于现有优化器的参数的方法,因此我没有直接的方法来验证优化器的内部状态是否已恢复,但是培训恢复的损失和准确性与快照已创建。 我还建议使用对Saver()的无参数调用,以便仍然保存未明确提及的状态变量,尽管这可能不是绝对必要的。

您可能还希望保存迭代或纪元号以便以后恢复,如本例所示: http://www.seaandsailor.com/tensorflow-checkpointing.html