上下文
我已切换到数据集API(基于this),与使用队列相比,这导致了非常显着的性能提升。
我正在使用Tensorflow 1.6。
问题
我已根据非常有用的解释here实施了重新取样。
问题在于,无论我将重采样阶段放在输入管道中,程序都会返回 ResourceExhaustedError 。更改batch_size似乎无法解决此问题,只有在使用所有输入文件的一小部分时才会解决此问题。
我的训练文件(.tfrecords)大小约为200 GB,并且分成几百个分片,但到目前为止,数据集API已经很好地处理了它们,并且它只是导致此问题的重新采样。
输入管道示例
batch_size = 20000
dataset = tf.data.Dataset.list_files(file_list)
dataset = dataset.apply(tf.contrib.data.parallel_interleave(
tf.data.TFRecordDataset, cycle_length=len(file_list), sloppy=True, block_length=10))
if resample:
dataset = dataset.apply(tf.contrib.data.rejection_resample(class_func=class_func, target_dist=target_dist, initial_dist= initial_dist,seed=5))
dataset = dataset.map(lambda _, data: (data))
dataset = dataset.shuffle(5*batch_size,seed=5)
dataset = dataset.apply(tf.contrib.data.map_and_batch(
map_func=_parse_function, batch_size=batch_size, num_parallel_batches=8))
dataset = dataset.prefetch(10)
return dataset
如果有人知道如何解决这个问题,我们将不胜感激!