我有一个tf.data.Dataset
个实例,其中包含3个不同的功能
label
这是一个标量sequence_feature
这是一系列标量seq_of_seqs_feature
这是一系列序列特征我正在尝试使用tf.data.Dataset.padded_batch()
生成填充数据作为我模型的输入 - 我希望以不同的方式填充每个功能。
批处理示例:
[{'label': 24,
'sequence_feature': [1, 2],
'seq_of_seqs_feature': [[11.1, 22.2],
[33.3, 44.4]]},
{'label': 32,
'sequence_feature': [3, 4, 5],
'seq_of_seqs_feature': [[55.55, 66.66]]}]
预期产出:
[{'label': 24,
'sequence_feature': [1, 2, 0],
'seq_of_seqs_feature': [[11.1, 22.2],
[33.3, 44.4]]},
{'label': 32,
'sequence_feature': [3, 4, 5],
'seq_of_seqs_feature': [[55.55, 66.66],
0.0, 0.0 ]}]
正如您所看到的,label
功能不应填充,sequence_feature
和seq_of_seqs_feature
应填充给定批次中相应的最长条目。
答案 0 :(得分:9)
tf.data.Dataset.padded_batch()
方法允许您为生成的批次的每个组件(要素)指定padded_shapes
。例如,如果您的输入数据集名为ds
:
padded_ds = ds.padded_batch(
BATCH_SIZE,
padded_shapes={
'label': [], # Scalar elements, no padding.
'sequence_feature': [None], # Vector elements, padded to longest.
'seq_of_seqs_feature': [None, None], # Matrix elements, padded to longest
}) # in each dimension.
请注意,padded_shapes
参数与输入数据集的元素具有相同的结构,因此在这种情况下,它需要一个包含与您的功能名称匹配的键的字典。