根据输入张量流的序列长度动态更改批次大小

时间:2020-10-07 15:54:55

标签: python-3.x tensorflow2.0 tensorflow-datasets

我得到了形状可变长度(seq_len)的序列。我想创建批次,以使每个批次的总序列长度(seq_len1 + seq_len2 + ...)不超过某个值。如何在Tensorflow中使用数据集api创建它?

例如在改组序列之后...

x1 = [1,2,3] (seq_len = 3)
x2 = [3,4,5,6,7] (seq_len = 5)
x3 = [4,5,6,7,8,9,5,8] (seq_len = 8)
x4 = [1,7,3] (seq_len = 3)
x5 = [6,7,8] (seq_len = 3)
x6 = [5,6,7,8] (seq_len = 4)
.
.

batch_size len = 10 (lets say)

batch1 = [x1,x2] (8<10)
batch2 = [x3]  (8<10)
batch3 = [x4,x5,x6] (10=10)

谢谢

0 个答案:

没有答案