tf.data.Dataset追加元素

时间:2017-11-14 19:14:15

标签: tensorflow

给定具有以下结构的TFRecord:

context_features = {
    "labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
    "filename": tf.FixedLenFeature(shape=[], dtype=tf.string)
}
sequence_features = {
    "labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64)
}

我想在解析记录文件时在运行时向labels字段追加一个元素。为此数据集编写的迭代器也会创建批处理并用零填充它们,因此在添加填充值之前在列表末尾添加元素非常重要。例如,如果我们有三个条目:[1.2],[3,4,5],[6,7,8,9]并且要添加的元素是10,则填充的批处理应如下所示: [1, 2, 10, 0, 0], [3, 4, 5, 10, 0], [6, 7, 8, 9, 10]

你能推荐我一种方法吗? 谢谢

1 个答案:

答案 0 :(得分:3)

要执行此操作,您只需使用tf.concat()将元素添加到labels的{​​{1}}张量中。例如,要将tf.parse_single_sequence_example()附加到每个标签:

10

请注意,我不确定您如何在程序中使用def _parse_labels_function(example): context_features = { "labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64), "filename": tf.FixedLenFeature(shape=[], dtype=tf.string) } sequence_features = { "labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64) } context_parsed, sequence_parsed = tf.parse_single_sequence_example( serialized=example, context_features=context_features, sequence_features=sequence_features ) # Append `10` to each label sequence. labels = tf.concat([sequence_parsed["labels"], [10]], 0) return labels, context_parsed["labels_length"], context_parsed["filename"] dataset = tf.data.TFRecordDataset(label_record) dataset = dataset.map(_parse_labels_function) 功能,但您可能还想在返回之前添加一个功能。