TensorFlow:数据集的多线程分批处理

时间:2019-07-11 15:14:21

标签: tensorflow tensorflow-datasets tensorflow2.0

我正在使用TensorFlow 2.0 beta。我有一个TensorFlow Dataset,其中每个元素都是一批特征列:张量的元组,其中每个元素都有batch_size记录的特定特征的值。我需要将这些记录整理为TFRecords以进行序列化,我想使用TensorFlow Dataset函数来做到这一点。展平的记录不必按确定的顺序生成。

以下是一些示例代码,演示了我要完成的工作:

batch_size = 100
num_batches = 10
input_data = (tf.constant(['text_data']), tf.constant(13))
ds = tf.data.Dataset.from_tensors(input_data).repeat(batch_size * num_batches)
ds = ds.batch(batch_size)
# ds = ... (multithreaded data transformations on batches of records happen here)
ds = ds.unbatch()

问题是我尝试这样做的方法不起作用或形成主要瓶颈,因为它们是单线程的。以下是其中一些方法:

  1. unbatch-单线程,太慢了
  2. interleave / flat_map-flat_map不接受张量元组-“采用2个位置参数,但” [num_features]个“给定”
  3. interleave /带有py_function的自定义函数-不起作用,因为py_function无法返回Dataset
  4. interleave /不带py_function的自定义函数-不起作用,因为在图形模式下,无法迭代张量

我需要用某种方式将unbatch替换为将批处理分配给多个线程,然后将批处理独立地解除批处理,然后交错来自不同线程的结果。有什么想法吗?

1 个答案:

答案 0 :(得分:0)

这是我最终找到的版本,将interleavefrom_tensor_slices结合使用:

batch_size = 100
num_batches = 10
num_threads = 4
input_data = (tf.constant(['text_data']), tf.constant(13))
ds = tf.data.Dataset.from_tensors(input_data).repeat(batch_size * num_batches)
ds = ds.batch(batch_size)
# ds = ... (multithreaded data transformations on batches of records happen here)
ds = ds.interleave(lambda *args:tf.data.Dataset.from_tensor_slices(args), num_threads, 1, num_threads)