如何过滤存储在要素图中的数据点?

时间:2016-09-06 18:58:21

标签: tensorflow

我正在使用TF.Learn estimator进行预测。 fit method的数据作为输入函数传递,该函数返回一个特征映射 - 一个python字典映射特征名称到将它们从光盘中排出的张量:

def input_fn():
    feature_columns = get_feature_columns()
    features = tf.contrib.layers.create_feature_spec_for_parsing(feature_columns=feature_columns)
    feature_map = tf.contrib.learn.io.read_batch_features(
      file_pattern=data_dir,
      batch_size=BATCH_SIZE,
      features=features)
    target = feature_map.pop("target")
    return feature_map, target

我想根据一些谓词P来过滤数据,这样估算器就可以获得批量BATCH_SIZE中的点,但只有那些满足P的点。我怎样才能轻松实现这一目标?

(问题类似于:How to filter tensor from queue based on some predicate in tensorflow?,但你只过滤了一个张量)

1 个答案:

答案 0 :(得分:1)

使用过滤队列并使用queuerunner从read_batch_features的结果中取出单个元素,并根据您的谓词在过滤队列中有条件地将其排队。