如何在TensorFlow 2.0中使用padded_batch()

时间:2019-12-21 00:45:22

标签: tensorflow2.0

X = tf.range(10)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset2 = dataset.repeat(3).padded_batch(7, padded_shapes=([]))
for item in dataset2:
    print(item)

输出

tf.Tensor([0 1 2 3 4 5 6], shape=(7,), dtype=int32)
tf.Tensor([7 8 9 0 1 2 3], shape=(7,), dtype=int32)
tf.Tensor([4 5 6 7 8 9 0], shape=(7,), dtype=int32)
tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)
tf.Tensor([8 9], shape=(2,), dtype=int32)

如何定义pshaped_shapes以获取类似波纹管的结果?

tf.Tensor([0 1 2 3 4 5 6], shape=(7,), dtype=int32)
tf.Tensor([7 8 9 0 1 2 3], shape=(7,), dtype=int32)
tf.Tensor([4 5 6 7 8 9 0], shape=(7,), dtype=int32)
tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)
tf.Tensor([8 9 0 0 0 0 0], shape=(7,), dtype=int32)

1 个答案:

答案 0 :(得分:0)

我用batch(7)解决了这个问题。

dataset2 = dataset.repeat(3).batch(7).padded_batch(7, padded_shapes=([None]))

输出

tf.Tensor(
[[0 1 2 3 4 5 6]
 [7 8 9 0 1 2 3]
 [4 5 6 7 8 9 0]
 [1 2 3 4 5 6 7]
 [8 9 0 0 0 0 0]], shape=(5, 7), dtype=int32)