我在训练集中发现了一些错误的数据(错误标记的示例),虽然我已经修复了源代码,但我还是想继续尝试使用相同的数据集,所以我需要跳过这些记录。
我使用TFRecordReader并加载parse_single_example& shuffle_batch。我可以在某处提供过滤器吗?
答案 0 :(得分:4)
使用tf.train.shuffle_batch()
和enqueue_many=True
在docs中对如何执行此操作进行简短介绍。如果您可以确定示例是否使用图形操作进行了错误标记,那么您可以像这样过滤结果(改编自another SO answer):
X, y = tf.parse_single_example(...)
is_correctly_labelled = correctly_labelled(X, y)
X = tf.expand_dims(X, 0)
y = tf.expand_dims(y, 0)
empty = tf.constant([], tf.int32)
X, y = tf.cond(is_correctly_labelled,
lambda: [X, y],
lambda: [tf.gather(X, empty), tf.gather(y, empty)])
Xs, ys = tf.train.shuffle_batch(
[X, y], batch_size, capacity, min_after_dequeue,
enqueue_many=True)
tf.gather
只是一种获得零大小切片的方法。在numpy中它只是X[[], ...]
。