我正在NER
中使用tensorflow
之类的序列标记,并决定尝试使用tf.data
来查看是否可以使用我的模型改进IO性能。
目前我正在使用TFRecordWriter
预处理并保存我的训练/验证数据,这是tf.train.SequenceExample()
序列化为字符串。然后我用tf.data.TFRecordDataset
加载它,解析/ shuffle / padded_batch然后继续训练,这很好。
问题:
dataset
的情况下制作serializing
并将SeuquenceExamples保存到tfrecord
文件?tf.data.Dataset.from_tensor_slices()
,但在这种情况下似乎不适合,因为输入是不填充的不同长度的序列。答案 0 :(得分:2)
在这种情况下可以使用tf.data.Dataset.from_generator()
。例如,让我们说您的示例看起来像以下非常简单的数据,有两个功能(其中第二个代表顺序数据):
examples = [("foo", [1, 2, 3, 4, 5]),
("bar", [6, 7]),
("baz", [8, 9, 10])]
您可以使用以下代码将其转换为tf.data.Dataset
:
def example_generator():
for string_feature, sequence_feature in examples:
yield string_feature, sequence_feature
dataset = tf.data.Dataset.from_generator(
example_generator,
output_types=(tf.string, tf.int32),
output_shapes=([], [None]), # A scalar and a variable-length vector.
)