TensorFlow出色的Dataset抽象可以使用带有谓词的过滤:
filter filter(predicate)根据谓词过滤此数据集。
Args:谓词:映射张量的嵌套结构的函数 (具有由self.output_shapes和 self.output_types)转换为标量tf.bool张量。
这非常强大;因为该谓词可以过滤数据集内容。
问题是:是否可能有“相反”的过滤条件:例如过采样?
table_breaktime
似乎不可行,因为它不依赖于数据集内容:
take take(count)创建一个数据集,其中最多包含来自 该数据集。
Args:count:一个tf.int64标量tf.Tensor,代表 应采用此数据集的元素来形成新数据集。 如果count为-1,或者count大于此数据集的大小, 新的数据集将包含该数据集的所有元素。
答案 0 :(得分:2)
TensorFlow当前不提供此类功能,但是您可以使用flat_map
来获得所需的结果。在这种情况下,您将为输入数据集的每个元素创建一个新的数据集(tf.data.Dataset.from_tensors
),该数据集将生成该单个样本(.repeat
)的多个副本。
例如:
import numpy as np
import tensorflow as tf
def run(dataset):
el = dataset.make_one_shot_iterator().get_next()
vals = []
with tf.Session() as sess:
try:
while True:
vals.append(sess.run(el))
except tf.errors.OutOfRangeError:
pass
return vals
dataset = tf.data.Dataset.from_tensor_slices((np.array([1,2,3,4,5]), np.array([5,4,3,2,1])))
print('Original dataset with repeats')
print(run(dataset))
dataset = dataset.flat_map(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r))
print('Repeats flattened')
print(run(dataset))
将打印
Original dataset with repeats
[(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]
Repeats flattened
[1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
或者,您可以使用.interleave
获得相同的结果,但可以混合多个样本的副本(.flat_map
是.interleave
的特殊情况)。例如:
dataset = tf.data.Dataset.from_tensor_slices((np.array([1,2,3,4,5]), np.array([5,4,3,2,1])))
dataset = dataset.interleave(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r), 4, 1)
print('Repeats flattened with a little bit of deterministic mixing')
print(run(dataset))
将打印
Repeats flattened with a little bit of deterministic mixing
[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 1, 2, 5, 1]