在TensorFlow中展平数据集

时间:2018-04-21 22:53:13

标签: python tensorflow flatten tensor tensorflow-datasets

我正在尝试将TensorFlow中的数据集转换为具有多个单值张量。数据集目前看起来像这样:

<head>

转换后应该如下所示:

init

我最初的想法是在数据集上使用fbq,然后使用[12 43 64 34 45 2 13 54] [34 65 34 67 87 12 23 43] [23 53 23 1 5] ... [12] [43] [64] [34] [45] [2] [13] [54] [34] [65] [34] [67] [87] [12] ... 将每个张量转换为张量列表:

flat_map

然而,每个张量的形状仅部分已知(即reshape),这就是拆除操作失败的原因。还有什么方法可以让#34; concat&#34;不明确地迭代它们的不同张量?

1 个答案:

答案 0 :(得分:1)

您的解决方案非常接近,但Dataset.flat_map()采用的函数返回tf.data.Dataset对象,而不是张量列表。幸运的是,Dataset.from_tensor_slices()方法适用于您的用例,因为它可以将张量分割为可变数量的元素:

output_labels = self.dataset.flat_map(tf.data.Dataset.from_tensor_slices)

请注意,tf.contrib.data.unbatch()转换实现了相同的功能,并且在TensorFlow的当前主分支中具有稍高效的实现(将包含在1.9版本中):

output_labels = self.dataset.apply(tf.contrib.data.unbatch())