将新存储桶添加到Tensorflow的seq2seq_model中

时间:2016-06-17 18:01:23

标签: tensorflow

我正在学习如何使用Tensorflow seq2seq_model(https://www.tensorflow.org/versions/r0.9/tutorials/seq2seq/index.html)。我在训练期间遇到了一个问题

seq2seq_model.Seq2SeqModel(..., listOfBuckets, ...)

需要永远,因为存储桶将其列为大型。如果我通过扩展每个存储桶边框来尝试使用较小的列表,则对model.step(..)的调用将永远进行。

我的解决方案是有一个循环,其中每次迭代都会创建一个新的Seq2SeqModel并将参数保存在我用来在下一次迭代中初始化Seq2SeqModel的文件中。看起来像这样:

While cond:
   with tf.Session() as sess:
      model=seq2seq_model.Seq2SeqModel(..., listOfBuckets, ...)
      ckpt = tf.train.get_checkpoint_state(training_dir)
      if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
           print("Reading model parameters from %s"   %ckpt.model_checkpoint_path)
           model.saver.restore(session, ckpt.model_checkpoint_path)

      doSomething()

      checkpoint_path = os.path.join(training_dir, "model.ckpt")
      model.saver.save(sess, checkpoint_path,global_step=model.global_step)

   listOfBuckets = someNewlistOfBuckets

但这似乎是一个糟糕的解决方案,所以我的问题是是否有任何方法可以向模型添加新桶(在它已经创建之后),而不是一遍又一遍地创建它。

感谢。

1 个答案:

答案 0 :(得分:0)

确实,构建存储桶可能需要一段时间。现在TensorFlow中有动态图,因此您原则上可以为所有存储桶构建单个图。但这是一个很大的API更改,所以我们还没有对seq2seq做过。

另一方面,很奇怪使用较大的水桶会减慢你的下降速度。通常,步长的时间是线性的,因此使用两倍大的桶只会使事情减慢约2倍。然后,您只能将桶用于2的幂并获得一个非常合理的系统。你试过吗?也许你的步骤因其他原因而缓慢?