TensorFlow:读取数组的批处理功能

时间:2017-01-10 03:44:31

标签: tensorflow

我使用tf.contrib.learn.ReadBatchFeatureshttps://www.tensorflow.org/versions/master/api_docs/python/contrib.learn/input_processing#read_batch_features)作为输入函数的一部分读取Example个原语,返回Tensor个对象的字典。训练我的模型后,在predict上调用Estimator会将一批预测作为数组返回,我想将其与已知值进行比较。

我尝试通过调用tf.Session().run(labels)来获取已知值,其中labels是已知值的Tensor,从输入函数返回。但是,此时,我的程序挂起了。我怀疑它是从磁盘读取标签的无限循环,而不是只是按照我的意愿阅读一批。

这是在labels Tensor中获取一批值的正确方法吗?

编辑:我已经尝试启动队列运行器,以下是正确的吗?

_, labels = eval_input_fn()
with tf.Session().as_default():
  tf.local_variables_initializer()
  tf.train.start_queue_runners()
  label_values = labels.eval()
print(label_values)

1 个答案:

答案 0 :(得分:2)

您需要的整个设置是:

TypeScript