共享变量 - 两个队列

时间:2016-12-12 12:50:52

标签: tensorflow

感谢Tensorflow multithreading image loading,我有这个加载数据函数,给定一个csv文件,例如一个训练csv文件,它创建一些数据节点;

 34 def loadData(csvPath,shape, batchSize=10,batchCapacity=40,nThreads=16):
 35     path, label = readCsv(csvPath)
 36     labelOh = oneHot(idx=label)
 37     pathRe = tf.reshape(path,[1])
 38     
 39     # Define subgraph to take filename, read filename, decode and enqueue
 40     image_bytes = tf.read_file(path)
 41     decoded_img = tf.image.decode_jpeg(image_bytes)
 42     decoded_img = prepImg(decoded_img,shape=shape)
 43     imageQ = tf.FIFOQueue(128,[tf.float32,tf.float32,tf.string], shapes = [shape,[447],[1]])
 44     enQ_op = imageQ.enqueue([decoded_img,labelOh,pathRe])
 45     
 46     NUM_THREADS = nThreads
 47     Q = tf.train.QueueRunner(
 48             imageQ,
 49             [enQ_op]*NUM_THREADS,
 50             imageQ.close(),
 51             imageQ.close(cancel_pending_enqueues=True)
 52             )
 53 
 54     tf.train.add_queue_runner(Q)
 55     dQ = imageQ.dequeue()
 56     X, Y, Ypaths = tf.train.batch(dQ, batch_size = batchSize, capacity = batchCapacity)
 57     return X, Y, Ypaths

然后我调用它并拥有标准模型,损失,训练子图如;

xTr, yTr, yPathsTr = loadData(trainCsvPath, *args)
yPredTr = model1(xTr,*args)
ce = ... # some loss function
learningRate = tf.placeholder(tf.float32)
trainStep = tf.train.AdamOptimizer(learningRate).minimize(ce)

然后我继续训练模型中的权重。据我所知,到目前为止,我不需要将数据输入feed_dict,因为它已经定义了。

with tf.Session() as sess:
     coord = tf.train.Coordinator()
     threads = tf.train.start_queue_runners(sess=sess,coord=coord)
     while not coord.should_stop(): 
           sess.run([trainStep],feed_dict={learningRate:lr})

我的问题是现在;

结合火车/测试过程的最佳方法是什么?即一旦线程完成了训练csv文件,他们就会读取测试csv文件,然后我运行另一个会话,我有类似的东西;

xTe, yTe, yPathsTe = loadData(csvPathTe, *args)
yPredTe = model1(xTe,*args) #will this use the same weights as the trained network? Or am I defining another seperate subgraph?
ce = ... # redefined for yPredTe
while not coord.should_stop(): 
      ce.eval() # return losses

运行直到测试csv文件完成。

我将如何冲洗并重复这些步骤(可能改组我的训练集)一定数量的时期?我也应该有一个csv队列吗?

1 个答案:

答案 0 :(得分:3)

唉,目前这个问题没有好的答案。典型的评估工作流程涉及运行单独的流程,定期执行以下操作(例如evaluate() in cifar10_eval.py):

  1. 构建一个图表,其中包含一个知道评估集的输入管道,模型的副本,评估操作(如果有)和tf.train.Saver
  2. 创建新会话。
  3. 恢复该会话中培训流程编写的最新检查点。
  4. 运行测试操作(例如问题中的ce)并在Python中累积结果,直到获得tf.errors.OutOfRangeError
  5. 我们目前正致力于改进输入管道,这样可以更轻松地多次迭代文件,并重复使用相同的会话。