TensorFlow checkpointing with an input pipeline

Date: 2017-06-09 12:00:42

Tags: tensorflow

We have the following input pipeline:

def inputs(filename, batch_size, num_epochs):
  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs)

    # Even when reading in multiple threads, share the filename
    # queue.
    image, label = read_and_decode(filename_queue)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in two threads to avoid being a bottleneck.
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=1000)

    return images, sparse_labels
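The read_and_decode function is not shown in the question. A minimal sketch, assuming TensorFlow 1.x and TFRecord files whose examples carry a raw image byte string and an int64 label (the feature names image_raw and label, and the 28x28 image size, are assumptions for illustration):

```python
import tensorflow as tf

def read_and_decode(filename_queue):
  # Read one serialized tf.train.Example from the TFRecord file.
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })
  # Decode the raw bytes and give the tensor a static shape so that
  # tf.train.shuffle_batch can batch it.
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape([28 * 28])
  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(features['label'], tf.int32)
  return image, label
```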

We train as follows:

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
  step = 0
  while not coord.should_stop():
    start_time = time.time()

    # Run one step of the model.  The return values are
    # the activations from the `train_op` (which is
    # discarded) and the `loss` op.  To inspect the values
    # of your ops or variables, you may include them in
    # the list passed to sess.run() and the value tensors
    # will be returned in the tuple from the call.
    _, loss_value = sess.run([train_op, loss])

    duration = time.time() - start_time

    # Print an overview fairly often.
    if step % 100 == 0:
      print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                 duration))
    step += 1
except tf.errors.OutOfRangeError:
  print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
finally:
  # When done, ask the threads to stop.
  coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

I have two questions:

1) Does the variable num_epochs determine the number of training iterations?

2) My model is large, and I want to checkpoint it, restore it, and continue training. How do I know how many iterations the restored model has already completed and how many remain?

1 Answer:

Answer 0 (score: 0)

1) As described in the TensorFlow API docs for tf.train.string_input_producer, a tf.errors.OutOfRangeError is raised once each string has been produced num_epochs times. So yes, num_epochs determines how long training runs in your code: the loop exits when the queue raises that error, after roughly num_epochs * num_examples / batch_size steps.

2) I think you could declare a num_epochs variable, increment its value for each epoch you run, and save it with the model; when you restore the model you can read the value back and train only the remaining epochs. Unfortunately I don't know of a smarter way, since most people simply save their model after training to use it for prediction, or fine-tune for some fixed number of epochs.
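Either way, once the completed step count is recovered, the remaining work is simple arithmetic. A sketch with hypothetical numbers (num_examples, batch_size, num_epochs, and the restored step count are all made up for illustration):

```python
# Hypothetical dataset and training configuration.
num_examples = 10000
batch_size = 100
num_epochs = 5

steps_per_epoch = num_examples // batch_size      # 100
total_steps = steps_per_epoch * num_epochs        # 500

# Hypothetical step count read back from a restored checkpoint.
steps_done = 320

remaining_steps = total_steps - steps_done        # 180
remaining_epochs = remaining_steps / steps_per_epoch  # 1.8
```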

Hope this helps.