给定具有以下结构的TFRecord:
context_features = {
"labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
"filename": tf.FixedLenFeature(shape=[], dtype=tf.string)
}
sequence_features = {
"labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64)
}
我想在解析记录文件时在运行时向labels
字段追加一个元素。为此数据集编写的迭代器也会创建批处理并用零填充它们,因此在添加填充值之前在列表末尾添加元素非常重要。例如,如果我们有三个条目:[1.2],[3,4,5],[6,7,8,9]并且要添加的元素是10,则填充的批处理应如下所示:
[1, 2, 10, 0, 0],
[3, 4, 5, 10, 0],
[6, 7, 8, 9, 10]
你能推荐我一种方法吗? 谢谢
答案 0 :(得分:3)
要执行此操作,您只需使用tf.concat()
将元素添加到labels
的{{1}}张量中。例如,要将tf.parse_single_sequence_example()
附加到每个标签:
10
请注意,我不确定您如何在程序中使用def _parse_labels_function(example):
context_features = {
"labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
"filename": tf.FixedLenFeature(shape=[], dtype=tf.string)
}
sequence_features = {
"labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64)
}
context_parsed, sequence_parsed = tf.parse_single_sequence_example(
serialized=example,
context_features=context_features,
sequence_features=sequence_features
)
# Append `10` to each label sequence.
labels = tf.concat([sequence_parsed["labels"], [10]], 0)
return labels, context_parsed["labels_length"], context_parsed["filename"]
dataset = tf.data.TFRecordDataset(label_record)
dataset = dataset.map(_parse_labels_function)
功能,但您可能还想在返回之前添加一个功能。