我已经知道数据集为padded batch提供了padded_batch函数,如下所示:
padded_batch(
batch_size,
padded_shapes,
padding_values=None
)
我想将批处理中的每个维度填充到所有维度的最大长度,但上述功能无法实现此目的。
例如:
我的数据集中的元素如下:
([], [5,5,5,5,5])
([1], [6,6,6,6,6,6])
([2,2], [7,7,7,7,7,7,7])
([3,3,3], [8,8,8,8,8,8,8,8])
what I want after padding is
([0,0,0,0,0,0,0,0,0], [5,5,5,5,5,0,0,0])
([1,0,0,0,0,0,0,0,0], [6,6,6,6,6,6,0,0))
([2,2,0,0,0,0,0,0,0], [7,7,7,7,7,7,7,0])
([3,3,3,0,0,0,0,0,0], [8,8,8,8,8,8,8,8])
如果我应用这样的功能:
dataset = dataset.padded_batch(
4,
padded_shapes=(tf.TensorShape([None]),
tf.TensorShape([None]))
)
我得到的结果如下:
([0,0,0], [5,5,5,5,5,0,0,0])
([1,0,0], [6,6,6,6,6,6,0,0))
([2,2,0], [7,7,7,7,7,7,7,0])
([3,3,3], [8,8,8,8,8,8,8,8])
这不是我想要的,我很长时间都在努力解决这个问题,任何人都可以帮助我吗?