How do I create bucketed tensors from a TensorFlow Dataset object?

Asked: 2017-09-10 06:58:26

Tags: tensorflow

I created a TextLineDataset object with the following code:

dataset = TextLineDataset([text_path])

Then I want to create bucketed tensors from this dataset. I know there is an API called bucket_by_sequence_length. I tried feeding it an iterator obtained from dataset.make_one_shot_iterator(), but that does not work. How should I supply its input_length and tensors arguments?

2 Answers:

Answer 0: (Score: 1)

After some investigation, I found that bucket_by_sequence_length is designed to work with tensors that can be enqueued into a Queue, whereas the output of a Dataset iterator is different.

I then found that Dataset supports transformations of its own that can be used to produce a bucketed dataset.
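As an illustration of such a transformation, below is a minimal sketch using tf.data.experimental.group_by_window, assuming the text lines have already been mapped (e.g. with dataset.map) into variable-length integer token sequences; the names _length_to_bucket_key and _pad_batch and the values batch_size and bucket_width are illustrative, not part of the original answer.

import tensorflow as tf

batch_size = 32      # illustrative value
bucket_width = 10    # illustrative value: sequences are grouped per 10 tokens of length

def _length_to_bucket_key(seq):
    # Map each sequence to an integer bucket id derived from its length.
    return tf.cast(tf.shape(seq)[0] // bucket_width, tf.int64)

def _pad_batch(key, window):
    # Pad every sequence in the window to the longest sequence in that window.
    return window.padded_batch(batch_size, padded_shapes=[None])

bucketed = dataset.apply(
    tf.data.experimental.group_by_window(
        key_func=_length_to_bucket_key,
        reduce_func=_pad_batch,
        window_size=batch_size))

Each element of bucketed is then a padded batch whose sequences fall into the same length bucket, which is essentially what bucket_by_sequence_length does for queue-based inputs.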

Answer 1: (Score: -3)

Once you have the Dataset object, you can use the code below to generate batches with the bucket_by_sequence_length API.

import tensorflow as tf

# This will be used by bucket_by_sequence_length to batch the elements according to their length.
def _element_length_fn(x, y=None):
    return tf.shape(x)[0]


# These are the upper length boundaries for the buckets.
# Based on these boundaries, sentences are routed to different buckets.
# You can define as many boundaries as you want, but make sure the largest boundary
# covers the maximum sentence length in your dataset.
boundaries = [upper_boundary_for_batch]

# These define the batch sizes for the buckets.
# Here the batch_size is kept the same for every bucket, but it can be tuned per bucket.
# Per the documentation, there must be one batch size per bucket, i.e. len(bucket_boundaries) + 1.
# https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length
batch_sizes = [batch_size] * (len(boundaries) + 1)

# bucket_by_sequence_length returns a dataset transformation that has to be applied with dataset.apply.
# The important parameter here is pad_to_bucket_boundary: if True, the sentences are padded up to the
# bucket boundary; if False, they are padded to the maximum length found in the batch.
# The default padding value is 0, so nothing extra needs to be supplied here.
dataset = dataset.apply(tf.data.experimental.bucket_by_sequence_length(_element_length_fn,
                                                                       boundaries,
                                                                       batch_sizes,
                                                                       drop_remainder=True,
                                                                       pad_to_bucket_boundary=True))
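To consume the bucketed batches, a minimal sketch assuming TF 1.x graph mode (where dataset.make_one_shot_iterator(), as mentioned in the question, is available):

iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()  # padded batch of shape [batch_size, padded_length]

with tf.Session() as sess:
    batch = sess.run(next_batch)
    print(batch.shape)

In TF 2.x eager mode you would instead iterate the dataset directly, e.g. "for batch in dataset:".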