如何根据令牌数量对 tf 数据集进行批处理?

时间:2021-07-05 12:38:13

标签: python tensorflow tf.data.dataset

我正在尝试使用 tensorflow 复制名为“Attention is al you need”(https://arxiv.org/pdf/1706.03762.pdf)的论文的结果。大部分代码都在这里完成:https://www.tensorflow.org/text/tutorials/transformer 但有一些没有正确实现的小细节。其中之一是批处理操作。对于这种模型,批次是由令牌组成的数组。论文中指出,无论批次大小,批次都由 25k 个令牌组成。怎么办?

作为一个例子,让我们以这个列表列表 [[1, 2], [1], [1, 2, 3], [1], [1], [1]] 考虑令牌目标大小为 3每批令牌正确的解决方案是:[[[1, 2], [1]], [[1, 2, 3]], [[1], [1], [1]]]

0 个答案:

没有答案