I am trying to group my training examples by their length, as described in https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing, but I would like to do it with the new Data API. Is there a way to achieve this?
Here is my code:
import tensorflow as tf

vocabulary = ["This", "is", "my", "first", "example",
              "the", "second", "one", "How", "to", "bucket",
              "examples", "using", "new", "Data", "API"]

data = ["This is my first example",
        "How to bucket my examples using the new Data API",
        "This is the second one",
        "How to bucket my examples using the new Data API"]

BATCH_SIZE = 2

lookup_table = tf.contrib.lookup.index_table_from_tensor(vocabulary)
dataset = tf.data.Dataset.from_tensor_slices(data)

def tokenize(x):
    words = tf.string_split([x], " ").values
    return words

def lookup(x):
    ids = lookup_table.lookup(x)
    return ids

bucket_boundaries = [5, 10]

def bucketing(x):
    return tf.contrib.training.bucket_by_sequence_length(
        input_length=10,
        tensors=[x],
        batch_size=1,
        bucket_boundaries=bucket_boundaries,
        dynamic_pad=True
    )

# dataset = (dataset
#            .map(tokenize)
#            .map(lookup)
#            # .padded_batch(BATCH_SIZE, padded_shapes=[?])
#            )

dataset = (dataset
           .map(tokenize)
           .map(lookup)
           .map(bucketing)
           )

iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

init_op = tf.group(tf.global_variables_initializer(),
                   tf.tables_initializer(),
                   iterator.initializer)

sess = tf.Session()
sess.run(init_op)

for i in range(len(data)):
    batch = sess.run(next_batch)
    print(batch)
The expected output would be something like:
[[0 1 2 3 4],[0 1 5 6 7]]
[[8 9 10 2 11 12 5 13 14 15],[8 9 10 2 11 12 5 13 14 15]]
Instead, the code above throws an OutOfRangeError:
OutOfRangeError (see above for traceback): End of sequence
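To make the grouping I am after fully concrete, here is a plain-Python sketch of the bucketing semantics I expect (the helper names `bucket_id` and `bucket_examples` are mine for illustration, not part of any TensorFlow API):

```python
# Plain-Python sketch of the expected bucketing semantics:
# a sequence goes into the bucket of the first boundary its
# length does not exceed; within a bucket, examples are batched
# (in TF, dynamic_pad would then pad each batch to its longest sequence).

vocabulary = ["This", "is", "my", "first", "example",
              "the", "second", "one", "How", "to", "bucket",
              "examples", "using", "new", "Data", "API"]
word_to_id = {w: i for i, w in enumerate(vocabulary)}

data = ["This is my first example",
        "How to bucket my examples using the new Data API",
        "This is the second one",
        "How to bucket my examples using the new Data API"]

bucket_boundaries = [5, 10]

def bucket_id(length, boundaries):
    """Index of the first boundary that the length does not exceed."""
    for i, b in enumerate(boundaries):
        if length <= b:
            return i
    return len(boundaries)  # overflow bucket for very long sequences

def bucket_examples(sentences, boundaries):
    """Convert sentences to id sequences and group them by length bucket."""
    buckets = {}
    for s in sentences:
        ids = [word_to_id[w] for w in s.split(" ")]
        buckets.setdefault(bucket_id(len(ids), boundaries), []).append(ids)
    return buckets

buckets = bucket_examples(data, bucket_boundaries)
for key in sorted(buckets):
    print(buckets[key])
```

Running this prints exactly the two groups shown in the expected output above; my question is how to express the same grouping with `tf.data` transformations instead.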