TensorFlow: How to use bucketing with the new Data API

Posted: 2017-11-28 10:46:07

Tags: api tensorflow bucket

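I'm trying to group my training examples by their length (bucketing), as described here: https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing

But I want to use the new Data API instead. Is there a way to do this?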
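For reference, the bucket-assignment rule from the linked `bucket_by_sequence_length` docs can be sketched in plain Python. This is a minimal illustration, not TensorFlow code: the helper name `bucket_index` is hypothetical, and it assumes the documented convention that N boundaries give N + 1 buckets, with one extra bucket for lengths below the first boundary and one for lengths at or above the last.

```python
from bisect import bisect_right


def bucket_index(length, boundaries):
    """Hypothetical helper: return the bucket for a sequence of this length.

    Bucket 0 holds lengths < boundaries[0], bucket i holds lengths in
    [boundaries[i-1], boundaries[i]), and the last bucket holds lengths
    >= boundaries[-1], i.e. len(boundaries) + 1 buckets in total.
    """
    return bisect_right(boundaries, length)


bucket_boundaries = [5, 10]
for n in (3, 5, 9, 10, 42):
    print(n, "->", bucket_index(n, bucket_boundaries))
```

With these boundaries, the 5-word sentences in the question fall into bucket 1 and the 10-word sentences into bucket 2.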

Here is my code:

import tensorflow as tf

vocabulary = ["This", "is", "my", "first", "example",
              "the", "second", "one","How", "to", "bucket",
              "examples", "using", "new", "Data", "API"]

data = ["This is my first example",
        "How to bucket my examples using the new Data API",
        "This is the second one",
        "How to bucket my examples using the new Data API"]

BATCH_SIZE = 2

lookup_table = tf.contrib.lookup.index_table_from_tensor(vocabulary)
dataset = tf.data.Dataset.from_tensor_slices(data)


def tokenize(x):
    words = tf.string_split([x], " ").values
    return words


def lookup(x):
    ids = lookup_table.lookup(x)
    return ids


bucket_boundaries = [5, 10]


def bucketing(x):
    return tf.contrib.training.bucket_by_sequence_length(
        input_length=10,
        tensors=[x],
        batch_size=1,
        bucket_boundaries=bucket_boundaries,
        dynamic_pad=True
    )

# dataset = (dataset
#            .map(tokenize)
#            .map(lookup)
#            # .padded_batch(BATCH_SIZE, padded_shapes=[?])
#            )

dataset = (dataset
           .map(tokenize)
           .map(lookup)
           .map(bucketing)
           )

iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

init_op = tf.group(tf.global_variables_initializer(),
                   tf.tables_initializer(),
                   iterator.initializer)

sess = tf.Session()
sess.run(init_op)

for i in range(len(data)):
    batch = sess.run(next_batch)
    print(batch)

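The expected output is one batch per bucket, with the two 5-word sentences batched together and the two 10-word sentences batched together:

[[0 1 2 3 4],[0 1 5 6 7]]

[[8 9 10 2 11 12 5 13 14 15],[8 9 10 2 11 12 5 13 14 15]]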
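To make the desired grouping concrete, here is a plain-Python simulation of the intended pipeline (tokenize, look up, bucket by length, batch per bucket). It only models the expected behaviour and is not a TensorFlow solution; the function names are hypothetical. Assumptions: unknown words map to -1 (the default `default_value` of `index_table_from_tensor`), short rows are padded with a hypothetical `pad_value` of -1 because 0 is a real id here, and a batch is emitted as soon as a bucket holds `BATCH_SIZE` examples, with leftover partial batches flushed at the end. That is roughly how a windowed grouping transformation such as `group_by_window` behaves (in TF 1.x releases of this era it lived under `tf.contrib.data`).

```python
from bisect import bisect_right

vocabulary = ["This", "is", "my", "first", "example",
              "the", "second", "one", "How", "to", "bucket",
              "examples", "using", "new", "Data", "API"]

data = ["This is my first example",
        "How to bucket my examples using the new Data API",
        "This is the second one",
        "How to bucket my examples using the new Data API"]

BATCH_SIZE = 2
bucket_boundaries = [5, 10]


def to_ids(sentence, word_to_id):
    # Split on spaces and map each word to its vocabulary index (-1 if unknown).
    return [word_to_id.get(word, -1) for word in sentence.split(" ")]


def pad_batch(batch, pad_value=-1):
    # Pad every row to the length of the longest row in this batch.
    width = max(len(row) for row in batch)
    return [row + [pad_value] * (width - len(row)) for row in batch]


def bucketed_batches(sentences, vocab, boundaries, batch_size):
    word_to_id = {word: i for i, word in enumerate(vocab)}
    buckets = {}  # bucket index -> examples waiting to be batched
    for sentence in sentences:
        ids = to_ids(sentence, word_to_id)
        key = bisect_right(boundaries, len(ids))
        buckets.setdefault(key, []).append(ids)
        if len(buckets[key]) == batch_size:
            # Emit a full batch as soon as a bucket fills up.
            yield pad_batch(buckets.pop(key))
    # Flush leftover partial batches once the data is exhausted.
    for key in sorted(buckets):
        yield pad_batch(buckets[key])


for batch in bucketed_batches(data, vocabulary, bucket_boundaries, BATCH_SIZE):
    print(batch)
# [[0, 1, 2, 3, 4], [0, 1, 5, 6, 7]]
# [[8, 9, 10, 2, 11, 12, 5, 13, 14, 15], [8, 9, 10, 2, 11, 12, 5, 13, 14, 15]]
```

Because the sentences in each bucket already have equal length, no padding is visible in this run; it only matters when a bucket spans several lengths.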

The code above throws an OutOfRangeError:

OutOfRangeError (see above for traceback): End of sequence
