我使用tf.contrib.learn.ReadBatchFeatures
(https://www.tensorflow.org/versions/master/api_docs/python/contrib.learn/input_processing#read_batch_features)作为输入函数的一部分读取Example
个原语,返回Tensor
个对象的字典。训练我的模型后,在predict
上调用Estimator
会将一批预测作为数组返回,我想将其与已知值进行比较。
我尝试通过调用tf.Session().run(labels)
来获取已知值,其中labels
是已知值的Tensor
,从输入函数返回。但是,此时,我的程序挂起了。我怀疑它是从磁盘读取标签的无限循环,而不是只是按照我的意愿阅读一批。
这是在labels
Tensor
中获取一批值的正确方法吗?
_, labels = eval_input_fn()
with tf.Session().as_default():
tf.local_variables_initializer()
tf.train.start_queue_runners()
label_values = labels.eval()
print(label_values)
答案 0 :(得分:2)
您需要的整个设置是:
TypeScript