如何使用TensorFlow对整个数据集进行混洗?

时间:2017-06-28 02:37:41

标签: tensorflow shuffle

现在我使用以下功能进行改组

from tensorflow.contrib import data
def input_pipeline(filenames, batch_size):
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
    dataset = data.TextLineDataset(filenames)
    dataset = dataset.map(decode_func)
    dataset = dataset.shuffle(buffer_size=10000)  # Equivalent to min_after_dequeue=10000.
    dataset = dataset.batch(batch_size)

    # Return an *initializable* iterator over the dataset, which will allow us to
    # re-initialize it at the beginning of each epoch.
    return dataset.make_initializable_iterator() 

但它只会以buffer_size的数量对数据进行随机播放,并且会在订单中填充buffer

我的数据非常庞大,我无法设置buffer_size太大。有没有其他解决方案可以改变整个数据集?

1 个答案:

答案 0 :(得分:1)

当前,Dataset API中不支持将整个数据集改组(大于10k的示例)。根据{{​​3}}线程,常见方法是:

  
      
  1. 使用   MapReduce / Spark / Beam / etc。创建一组大小大致相等的作业   文件(“碎片”)。
  2.   
  3. 在每个时期:

         

    a。使用Dataset.list_files(...)。shuffle(num_shards)随机调整分片文件名列表。

         

    b。使用dataset.interleave(lambda文件名:tf.data.TextLineDataset(文件名),cycle_length = N)将来自N个不同分片的记录混合在一起。

         

    c。使用dataset.shuffle(B)可以对结果数据集进行随机排序。设置B可能需要进行一些试验,但是您可能希望将其设置为比单个分片中的记录数更大的值。

  4.