使用TensorFlow进行增量训练

时间:2016-09-28 14:07:14

标签: tensorflow

我想训练一个模型来对90K标签进行分类,所以我使用了所谓的增量训练。

我最初训练模型仅对1K标签进行分类,然后添加另外1K标签并将最终FC层的输出尺寸扩展到2K,并训练更多的纪元。之后我添加了另外1K标签,依此类推......

请注意,它不是微调,其中最后一个FC之前的所有参数都是固定的,因此我可以缓存输出功能。在我的情况下,我需要更新每个阶段的所有变量。

我设计的解决方案是:

  1. 训练1K标签。
  2. 保存模型。
  3. 修改图表,让最后一个FC层输出2K维。
  4. 初始化所有变量
  5. 加载上一个检查点,它将覆盖所有参数,但是最后一层的权重。
  6. 再次训练并重复
  7. 所以这里的关键点是实现部分恢复检查点。

    在TensorFlow中,我使用此类代码加载检查点:

    saver.restore(sess, "model.ckpt")
    

    但是,当形状不匹配时它会失败。

    任何人都可以提供帮助,无论是部分恢复/初始化变量,还是如何以其他方式实施增量培训?

1 个答案:

答案 0 :(得分:1)

目前这并不容易。我们正在积极添加新的API以使其更容易。

与此同时,如果您确定,:),您可以在更改FC图层的大小时尝试以下操作:

  • 创建一个读者: reader = tf.train.NewCheckpointReader(your_checkpoint_file)
  • 在检查点文件中加载所有变量: cur_vars = reader.get_variable_to_shape_map()。keys()
  • 删除原始FC图层: cur_vars_without_fc = cur_vars - your_fc_layer_var_name
  • 使用以下变量创建保护程序: saver = tf.Saver(cur_vars_without_fc) saver.restore(sess,your_checkpoint_file)
  • 初始化新的FC图层变量: sess.run([your_fc_layer_var.initializer])

希望有所帮助!

雪利酒