当获得`DirectedInterleave选择了疲惫的输入`警告时,TensorFlow的`sample_from_datasets`是否仍从数据集中采样吗?

时间:2019-07-30 18:44:21

标签: python tensorflow tensorflow-datasets tensorflow2.0

当使用TensorFlow的tf.data.experimental.sample_from_datasets同样地从两个非常不平衡的数据集中采样时,我最终收到了DirectedInterleave selected an exhausted input: 0警告。基于this GitHub issue,当sample_from_datasets内的一个数据集中的所有示例都用光了,并且需要对已经看到的示例进行采样时,似乎就会发生这种情况。

耗尽的数据集是否仍会产生样本(从而保持所需的平衡训练比率),还是该数据集没有采样,因此训练再次变得不平衡?如果是后者,是否有一种方法可以用sample_from_datasets产生所需的平衡训练率?

注意:正在使用TensorFlow 2 Beta

1 个答案:

答案 0 :(得分:2)

较小的数据集不会重复-一旦用尽,其余的将仅来自仍然具有示例的较大的数据集。

您可以通过执行以下操作来验证此行为:

def data1():
  for i in range(5):
    yield "data1-{}".format(i)

def data2():
  for i in range(10000):
    yield "data2-{}".format(i)

ds1 = tf.data.Dataset.from_generator(data1, tf.string)
ds2 = tf.data.Dataset.from_generator(data2, tf.string)

sampled_ds = tf.data.experimental.sample_from_datasets([ds2, ds1], seed=1)

然后,如果我们遍历sampled_ds,我们将发现data1用尽后不会产生任何样本:

tf.Tensor(b'data1-0', shape=(), dtype=string)
tf.Tensor(b'data2-0', shape=(), dtype=string)
tf.Tensor(b'data2-1', shape=(), dtype=string)
tf.Tensor(b'data2-2', shape=(), dtype=string)
tf.Tensor(b'data2-3', shape=(), dtype=string)
tf.Tensor(b'data2-4', shape=(), dtype=string)
tf.Tensor(b'data1-1', shape=(), dtype=string)
tf.Tensor(b'data1-2', shape=(), dtype=string)
tf.Tensor(b'data1-3', shape=(), dtype=string)
tf.Tensor(b'data2-5', shape=(), dtype=string)
tf.Tensor(b'data1-4', shape=(), dtype=string)
tf.Tensor(b'data2-6', shape=(), dtype=string)
tf.Tensor(b'data2-7', shape=(), dtype=string)
tf.Tensor(b'data2-8', shape=(), dtype=string)
tf.Tensor(b'data2-9', shape=(), dtype=string)
tf.Tensor(b'data2-10', shape=(), dtype=string)
tf.Tensor(b'data2-11', shape=(), dtype=string)
tf.Tensor(b'data2-12', shape=(), dtype=string)
...
---[no more 'data1-x' examples]--
...

当然,您可以 像这样重复data1

sampled_ds = tf.data.experimental.sample_from_datasets([ds2, ds1.repeat()], seed=1)

但从评论看来,您已经意识到这一点,并且不适用于您的情况。

  

如果是后者,是否有一种方法可以使用sample_from_datasets产生所需的平衡训练率?

好吧,如果您有2个长度不同的数据集,并且从此开始均匀采样,那么看来您只有2个选择:

  • 重复较小的数据集n次(其中n ≃ len(ds2)/len(ds1)
  • 在较小的数据集用完后停止采样

要实现第一个目标,您可以使用ds1.repeat(n)

要获得秒数,可以使用ds2.take(m),其中m=len(ds1)