我正在使用TF.Learn estimator进行预测。 fit method的数据作为输入函数传递,该函数返回一个特征映射 - 一个python字典映射特征名称到将它们从光盘中排出的张量:
def input_fn():
feature_columns = get_feature_columns()
features = tf.contrib.layers.create_feature_spec_for_parsing(feature_columns=feature_columns)
feature_map = tf.contrib.learn.io.read_batch_features(
file_pattern=data_dir,
batch_size=BATCH_SIZE,
features=features)
target = feature_map.pop("target")
return feature_map, target
我想根据一些谓词P
来过滤数据,这样估算器就可以获得批量BATCH_SIZE中的点,但只有那些满足P
的点。我怎样才能轻松实现这一目标?
(问题类似于:How to filter tensor from queue based on some predicate in tensorflow?,但你只过滤了一个张量)
答案 0 :(得分:1)
使用过滤队列并使用queuerunner从read_batch_features的结果中取出单个元素,并根据您的谓词在过滤队列中有条件地将其排队。