训练与队列测试

时间:2016-11-25 10:13:51

标签: python machine-learning tensorflow

我正在使用here描述的设置分批加载一些训练图像,即基本上这样:

def read_my_file_format(filename_queue):
  # ... use a reader + a decoder

def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(...)
  example, label = read_my_file_format(filename_queue)
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, ...)
  return example_batch, label_batch

def build_net():
    batch, label = input_pipeline(...)
    y = encoder(batch)  # <- build network using the batch

def train():
  with tf.Session() as sess:
    # ... init vars

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
      while not coord.should_stop():
        # ... training step

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()

这对训练有好处 - 但是,我不知道如何测试最终的网络!令我困惑的是:

  • input_pipeline返回的张量是网络的一部分。为了测试,我将不得不更换它?
  • 我想我可以创建另一个input_pipeline进行测试,即使用不同的文件名队列。然后我可以使用tf.cond在不同的输入批次之间切换,但是:我如何确保一次只耗尽一个队列。我没有看到如何访问不同的队列以及如何指定它们的卸载方式。

基本上,这个问题归结为:测试使用tf.train.shuffle_batch方法构建的网络的规范方法是什么。

2 个答案:

答案 0 :(得分:1)

我的想法是使用字符串占位符,即假设您有多个输入文件:

filenames_place = tf.placeholder(tf.string, shape=[None])
num_epochs_place = tf.placeholder(tf.int32)
example_batch, label_batch = input_pipeline(filenames_place, batch_size, num_epochs_place)
...
try:
   sess.run(train_op, feed_dict={filenames_place: ["train_data1", "train_data2"], num_epochs_place=5})

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')

sess.run(eval_op, feed_dict={filenames_place: ["test_data"], num_epochs_place=1})

答案 1 :(得分:1)

您肯定是在为数据集评估创建额外输入管道的想法。使用multiple input pipelines是推荐的方法之一,它包括两个过程 - 一个培训和另一个评估。在训练过程中将使用检查点,然后每千步,代码可以针对训练和测试数据集尝试eval模型。

引自文档:

  
      
  • 训练过程读取训练输入数据并定期写入包含所有训练变量的检查点文件。
  •   
  • 评估过程将检查点文件恢复为读取验证输入数据的推理模型。
  •   

即使在训练完成/退出后,也可以评估。 (see this example

另一个考虑因素是sharing variables列车和eval可以在同一个过程中在同一个图表中运行,同时分享他们训练过的变量!

关于队列耗尽问题,如果你用tf.train.shuffle_batch*将num_threads设置为大于1,它会同时从单个文件中读取(+比1个线程更快),而不是一次读取N个文件, (参见batching)部分。

相关问题