如何释放Tensorflow 2.0数据集的批处理

时间:2020-07-08 14:54:32

标签: python tensorflow machine-learning keras tensorflow-datasets

我有一个数据集,该数据集是通过与tf.data.Dataset一起使用的以下代码创建的:

dataset = Dataset.from_tensor_slices(corona_new)
dataset = dataset.window(WINDOW_SIZE, 1, drop_remainder=True)
dataset = dataset.flat_map(lambda x: x.batch(WINDOW_SIZE))
dataset = dataset.map(lambda x: tf.transpose(x))

for i in dataset:
    print(i.numpy())
    break

当我运行它时,我得到以下输出(这是一个批处理的示例):

[[  0. 125. 111. 232. 164. 134. 235. 190.] 
 [  0.  14.  16.   7.   9.   7.   6.   8.]
 [  0. 132. 199. 158. 148. 141. 179. 174.]
 [  0.   0.   0.   2.   0.   2.   1.   2.]
 [  0.   0.   0.   0.   3.   5.   0.   0.]]

我该如何对它们进行批处理?

1 个答案:

答案 0 :(得分:0)

找到我的解决方案。

在TensorFlow 2.0中,您可以通过调用tf.data.Dataset函数来解除.unbatch()的批处理。

示例:dataset.unbatch()