从带有谓词的tf.data.Dataset中获取(如过滤器)

时间:2018-06-22 05:24:59

标签: tensorflow machine-learning deep-learning

TensorFlow出色的Dataset抽象可以使用带有谓词的过滤:

  

filter filter(predicate)根据谓词过滤此数据集。

     

Args:谓词:映射张量的嵌套结构的函数   (具有由self.output_shapes和   self.output_types)转换为标量tf.bool张量。

这非常强大;因为该谓词可以过滤数据集内容。

问题是:是否可能有“相反”的过滤条件:例如过采样?

table_breaktime似乎不可行,因为它不依赖于数据集内容:

  

take take(count)创建一个数据集,其中最多包含来自   该数据集。

     

Args:count:一个tf.int64标量tf.Tensor,代表   应采用此数据集的元素来形成新数据集。   如果count为-1​​,或者count大于此数据集的大小,   新的数据集将包含该数据集的所有元素。

1 个答案:

答案 0 :(得分:2)

TensorFlow当前不提供此类功能,但是您可以使用flat_map来获得所需的结果。在这种情况下,您将为输入数据集的每个元素创建一个新的数据集(tf.data.Dataset.from_tensors),该数据集将生成该单个样本(.repeat)的多个副本。

例如:

import numpy as np
import tensorflow as tf

def run(dataset):
    el = dataset.make_one_shot_iterator().get_next()
    vals = []
    with tf.Session() as sess:
        try:
            while True:
                vals.append(sess.run(el))
        except tf.errors.OutOfRangeError:
            pass

    return vals

dataset = tf.data.Dataset.from_tensor_slices((np.array([1,2,3,4,5]), np.array([5,4,3,2,1])))
print('Original dataset with repeats')
print(run(dataset))

dataset = dataset.flat_map(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r))
print('Repeats flattened')
print(run(dataset))

将打印

Original dataset with repeats
[(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]
Repeats flattened
[1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]

或者,您可以使用.interleave获得相同的结果,但可以混合多个样本的副本(.flat_map.interleave的特殊情况)。例如:

dataset = tf.data.Dataset.from_tensor_slices((np.array([1,2,3,4,5]), np.array([5,4,3,2,1])))
dataset = dataset.interleave(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r), 4, 1)
print('Repeats flattened with a little bit of deterministic mixing')
print(run(dataset))

将打印

Repeats flattened with a little bit of deterministic mixing
[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 1, 2, 5, 1]