TFRecordReader在数据输入管道中

时间:2016-08-28 07:29:12

标签: python queue tensorflow

我目前遇到了TFRecordReader

的实施问题

这是设置:

trainQ = tf.train.string_input_producer(fileList)
RecReader = tf.TFRecordReader()
batch_strings = RecReader.read(trainQ)
con,seq=tf.parse_single_sequence_example(batch_strings.value,context_features=lengths_context,sequence_features=convo_pair,name='parse_ex')
encoder_inputs,decoder_inputs,enc_len,dec_len = seq['utterance'],seq['response'],con['utter_length'],con['resp_length']
mini_batch = tf.train.batch([encoder_inputs,decoder_inputs,enc_len,dec_len,decoder_inputs],batch_size,2,capacity=50*batch_size,dynamic_pad = True,enqueue_many=False)
encoder_inp,decoder_inp,encoder_lens,decoder_lens,labels = mini_batch
...
<build rest of the model>
...
loss = <some loss>
train_ops = <optimizer>.minimize(loss)

现在当我train_ops.run()时,它会自动读取队列并在一批中训练模型。但是如果我想评估一些中间变量,我就不能variable.eval(),因为这意味着从trainQ队列读取一个新的批处理具有不同的值

我可以考虑绕过这种方法来使用占位符来提供parse_single_example并在每次火车循环中填充占位符。但有没有更好的方法来做到这一点,即评估变量而不再读取队列?

希望这不会令人困惑

1 个答案:

答案 0 :(得分:0)

如果要评估依赖于每100次迭代的批输入的中间层(称为conv3),您可以执行以下操作:

for step in range(100000):
    if step % 100 != 0:
        # only run the training operation
        sess.run(train_op)
    else:
        # run train_op AND `conv3` at the same time
        _, conv3_value = sess.run([train_op, conv3])

这里的诀窍是在train_op的同一个电话中拨打conv3tf.Session。这样,从训练队列中读取批处理以训练一步,但同时也使用它来计算conv3