如何使用tf.Dataset

时间:2019-05-24 17:57:57

标签: python tensorflow dataset shuffle tensorflow-datasets

我有几个TFRecords,每个记录中包含数千个示例,每个标签都是唯一的,并且是从上一个(1、2、3、4、5等)自动递增的。我的目标是将所有这些连接在一起,以特定的方式整理示例,然后进行示例和培训。我找到了一种解决方法,但最后似乎无法正常工作(没有按我所希望的改组)。

这里是这种情况:假设我总共有100个示例,混洗过程必须交换一些大小(例如4号)的块,最后用更大大小的批次(例如16个)进行训练)。 理想的情况是将这些块彼此之间的距离并不是真的很远,而是要使它们中的两个都不是相邻的(在为训练采样时)。因此,我总是选择相对较小的随机缓冲区大小(此缓冲区可能是我的问题所在吗?)。

我已经找到了一种方法来实现这一目标,但是显然,当使用越来越多的数据时,紧密的“块”(按标签)的数量会增加,甚至会得到非常紧密的块(然后是闭合标签)的很高的峰值。 / p>

这是我实施此变通办法的方式:

# Loading all the TFRecords and concatenating them...

dataset = dataset.batch(label_frames_tolerance)
dataset = dataset.shuffle(batch_shuffling_buffer_size)
dataset = dataset.apply(tf.data.experimental.unbatch())
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)

# Train...

按照我上面描述的示例,在这里您可以想象:

label_frames_tolerance = 4
batch_shuffling_buffer_size = 10
num_epochs = 10
batch_size = 16

仅1个时期的可能结果就是这样的:

[29 30 31 32 37 38 39 40 17 18 19 20 33 34 35 36  1  2  3  4 25 26 27]
[28 61 62 63 64 13 14 15 16 21 22 23 24 73 74 75 76 53 54 55 56 77 78]
[ 79  80  49  50  51  52   5   6   7   8  65  66  67  68  81  82  83  84
  97  98  99 100  45]
[46 47 48 89 90 91 92  9 10 11 12 57 58 59 60 85 86 87 88 69 70 71 72]
[93 94 95 96 41 42 43 44]

(由于数据量大,最后一个“破碎”批次不影响训练)

0 个答案:

没有答案