我正在使用TensorFlow 2.0 beta。我有一个TensorFlow Dataset
,其中每个元素都是一批特征列:张量的元组,其中每个元素都有batch_size
记录的特定特征的值。我需要将这些记录整理为TFRecords
以进行序列化,我想使用TensorFlow Dataset
函数来做到这一点。展平的记录不必按确定的顺序生成。
以下是一些示例代码,演示了我要完成的工作:
batch_size = 100
num_batches = 10
input_data = (tf.constant(['text_data']), tf.constant(13))
ds = tf.data.Dataset.from_tensors(input_data).repeat(batch_size * num_batches)
ds = ds.batch(batch_size)
# ds = ... (multithreaded data transformations on batches of records happen here)
ds = ds.unbatch()
问题是我尝试这样做的方法不起作用或形成主要瓶颈,因为它们是单线程的。以下是其中一些方法:
unbatch
-单线程,太慢了interleave
/ flat_map
-flat_map
不接受张量元组-“采用2个位置参数,但” [num_features]个“给定” interleave
/带有py_function
的自定义函数-不起作用,因为py_function
无法返回Dataset
interleave
/不带py_function
的自定义函数-不起作用,因为在图形模式下,无法迭代张量我需要用某种方式将unbatch
替换为将批处理分配给多个线程,然后将批处理独立地解除批处理,然后交错来自不同线程的结果。有什么想法吗?
答案 0 :(得分:0)
这是我最终找到的版本,将interleave
与from_tensor_slices
结合使用:
batch_size = 100
num_batches = 10
num_threads = 4
input_data = (tf.constant(['text_data']), tf.constant(13))
ds = tf.data.Dataset.from_tensors(input_data).repeat(batch_size * num_batches)
ds = ds.batch(batch_size)
# ds = ... (multithreaded data transformations on batches of records happen here)
ds = ds.interleave(lambda *args:tf.data.Dataset.from_tensor_slices(args), num_threads, 1, num_threads)