如何在每个维度填充的数据集中填充批次到所有维度的最大长度

时间:2017-12-23 12:49:21

标签: tensorflow

我已经知道数据集为padded batch提供了padded_batch函数,如下所示:

padded_batch(
    batch_size,
    padded_shapes,
    padding_values=None
)

我想将批处理中的每个维度填充到所有维度的最大长度,但上述功能无法实现此目的。

例如:

我的数据集中的元素如下:

([],      [5,5,5,5,5])

([1],     [6,6,6,6,6,6])

([2,2],   [7,7,7,7,7,7,7])

([3,3,3], [8,8,8,8,8,8,8,8])

what I want after padding is 

([0,0,0,0,0,0,0,0,0],      [5,5,5,5,5,0,0,0])

([1,0,0,0,0,0,0,0,0],      [6,6,6,6,6,6,0,0))

([2,2,0,0,0,0,0,0,0],      [7,7,7,7,7,7,7,0])

([3,3,3,0,0,0,0,0,0],      [8,8,8,8,8,8,8,8])

如果我应用这样的功能:

dataset = dataset.padded_batch(
            4, 
            padded_shapes=(tf.TensorShape([None]), 
            tf.TensorShape([None]))                            
          ) 

我得到的结果如下:

([0,0,0],      [5,5,5,5,5,0,0,0])

([1,0,0],      [6,6,6,6,6,6,0,0))

([2,2,0],      [7,7,7,7,7,7,7,0])

([3,3,3],      [8,8,8,8,8,8,8,8])

这不是我想要的,我很长时间都在努力解决这个问题,任何人都可以帮助我吗?

0 个答案:

没有答案