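I'm trying to maximize GPU utilization during training. I have variable-length sequences that I want to pack densely into fixed-length batches. Basically, I want a short sequence to be followed immediately by another sequence, and I want long sequences to be split so that they continue in the next batch. For example: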
```
// Say batch size is 2 and desired sequence length is 4
s1 = [a, b, c, d, e, f]
s2 = [x, y, z]
s3 = [l, m, n, o]

// Resulting batches:
b1 = [[a, b, c, d]
      [x, y, z, l]]
b2 = [[e, f, _, _]
      [m, n, o, _]]
```
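The packing described above can be pinned down with a plain-Python reference (no TensorFlow), which is handy for checking any in-graph version against. This is a sketch under stated assumptions: each of the batch_size rows is treated as its own stream that pulls the next whole sequence from a shared source whenever it runs short, leftovers stay in the same row for the next batch, and padding (`_`) appears only once the source is exhausted. The function name pack_batches is hypothetical.

```python
from typing import Iterable, Iterator, List, Sequence


def pack_batches(sequences: Iterable[Sequence[str]], batch_size: int,
                 seq_len: int, pad: str = "_") -> Iterator[List[List[str]]]:
    """Hypothetical helper: pack variable-length sequences into fixed-size batches."""
    source = iter(sequences)
    carry: List[List[str]] = [[] for _ in range(batch_size)]
    exhausted = False
    while True:
        batch: List[List[str]] = []
        real_tokens = 0
        for i in range(batch_size):
            row = carry[i]
            # Pull whole sequences from the shared source until this row is full.
            while len(row) < seq_len and not exhausted:
                try:
                    row = row + list(next(source))
                except StopIteration:
                    exhausted = True
            chunk = row[:seq_len]
            real_tokens += len(chunk)
            # Padding is only needed once the source has run dry.
            batch.append(chunk + [pad] * (seq_len - len(chunk)))
            # Whatever did not fit continues in the same row of the next batch.
            carry[i] = row[seq_len:]
        if real_tokens == 0:  # every row is empty: nothing left to emit
            return
        yield batch


s1, s2, s3 = list("abcdef"), list("xyz"), list("lmno")
for b in pack_batches([s1, s2, s3], batch_size=2, seq_len=4):
    print(b)
# [['a', 'b', 'c', 'd'], ['x', 'y', 'z', 'l']]
# [['e', 'f', '_', '_'], ['m', 'n', 'o', '_']]
```

Because rows are filled in order, s3 lands in row 1: row 0 is already full with the start of s1, while row 1 still has one free slot after s2.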
Is there an easy way to do this in TensorFlow? My sequences come from a tf.TextLineReader:
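```python
import tensorflow as tf

# string_input_producer expects a list of filenames, not a single string
file_queue = tf.train.string_input_producer(['./example_text'])
reader = tf.TextLineReader()
key, sentence = reader.read(file_queue)
# convert string to int32 vector (to_sequence is my own helper, not shown)
sequence_tensor = to_sequence(sentence)
# what I wish I had (no such function exists in TensorFlow):
batch = tf.fixed_length_batch_from_variable_length_sequences(
    sequence_tensor, batch_size, fixed_length)
```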
Thanks in advance for any suggestions.
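For comparison, when the only requirement is dense rows that continue from one batch to the next, a simpler layout concatenates everything into one token stream, gives each row one contiguous slice of it, and cuts those slices into seq_len windows. This is roughly how the reader in TensorFlow's PTB language-model tutorial builds batches (a reshape to [batch_size, row_len] followed by column slices). The plain-Python sketch below assumes the whole corpus fits in memory; the function name contiguous_batches is hypothetical, and unlike the packing asked for above it drops the ragged tail instead of padding it.

```python
from typing import List, Sequence


def contiguous_batches(sequences: Sequence[Sequence[str]], batch_size: int,
                       seq_len: int) -> List[List[List[str]]]:
    """Hypothetical helper: PTB-style contiguous row layout."""
    # Flatten every sequence into a single token stream.
    stream = [tok for seq in sequences for tok in seq]
    # Give each row one contiguous slice of the stream; the ragged tail is dropped.
    row_len = len(stream) // batch_size
    rows = [stream[r * row_len:(r + 1) * row_len] for r in range(batch_size)]
    # Cut every row into consecutive windows, so row r of batch b + 1
    # continues exactly where row r of batch b stopped.
    num_batches = row_len // seq_len
    return [[row[b * seq_len:(b + 1) * seq_len] for row in rows]
            for b in range(num_batches)]


s1, s2, s3 = list("abcdef"), list("xyz"), list("lmno")
print(contiguous_batches([s1, s2, s3], batch_size=2, seq_len=3))
```

The trade-off is that the length of the whole stream must be known up front, and with seq_len=4 on the example above only one full batch fits, so the tokens e, f, m, n and o are lost rather than padded.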
Answer 0 (score: 0)
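OK, I have a working example that does almost what I want. The following code produces batches the way I want, but it needs placeholders to pass data into and out of the TF session. I'd like to be able to build these batches entirely within the TF graph.
Hopefully I'm just being silly and someone can point out an obvious solution. Please excuse the camelCase.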
```python
from __future__ import print_function

import tensorflow as tf


def buildBatch(seqLength, batchSize):
    def lineToSequence(line):
        # Split one line into tokens and append a <GO> token as the sentence boundary
        line = tf.expand_dims(line, axis=0)
        line = tf.sparse_tensor_to_dense(tf.string_split(line), '_')
        line = tf.concat([line, [['<GO>']]], 1)
        return line

    data = tf.contrib.data.TextLineDataset(['./exampleFile.txt'])
    data = data.map(lineToSequence)
    iterator = data.make_initializable_iterator()

    # Grab lines from the file until the sequence length is met and shave off any extra
    def getFixedLengthSequence(start):
        c = lambda s: tf.shape(s)[1] < seqLength  # while the sequence is too short
        b = lambda s: tf.concat([s, iterator.get_next()], 1)  # concatenate the next line
        sentences = tf.while_loop(c, b, [start], back_prop=False, parallel_iterations=1,
                                  shape_invariants=[tf.TensorShape([1, None])])
        clippedToLength = tf.expand_dims(sentences[0, :seqLength], axis=0)
        leftover = tf.expand_dims(sentences[0, seqLength:], axis=0)
        return clippedToLength, leftover

    # Placeholders pass in the start of each sequence (saved from the last batch)
    startOfThisBatch = [tf.placeholder(tf.string, shape=[1, None]) for i in range(batchSize)]
    # Capture what is left over from each sequence so it can be passed in to start the next batch
    startOfNextBatch = [tf.TensorArray(tf.string, size=1) for i in range(batchSize)]

    # Build the batch
    thisBatch = []
    for i, seqStart in enumerate(startOfThisBatch):
        seq, leftover = getFixedLengthSequence(seqStart)
        thisBatch.append(seq)
        startOfNextBatch[i] = startOfNextBatch[i].write(0, leftover)
    thisBatch = tf.concat(thisBatch, axis=0)
    startOfNextBatch = [b.read(0) for b in startOfNextBatch]

    return thisBatch, startOfThisBatch, startOfNextBatch, iterator.initializer


def printBatch():
    sequenceLength = 10
    batchSize = 3
    batch, startOfThisBatch, startOfNextBatch, iteratorInit = buildBatch(sequenceLength, batchSize)

    # The very first batch starts with <GO> tokens
    batchStarts = [[['<GO>']]] * batchSize

    sv = tf.train.Supervisor()
    with sv.managed_session() as sess:
        sess.run(iteratorInit)
        for b in range(4):
            # Populate the feed dict with the beginning of each sequence in the batch
            feed = {}
            for i in range(batchSize):
                feed[startOfThisBatch[i]] = batchStarts[i]
            # Call TF to get this batch and the starting sequences of the next batch
            out, batchStarts = sess.run([batch, startOfNextBatch], feed_dict=feed)
            print('Batch', b, ':')
            for seq in out:
                # String tensors come back as bytes, so decode before joining
                print(' '.join(tok.decode('utf-8') for tok in seq))
            print()


printBatch()
```
Output:
```
Batch 0 :
<GO> A spokesman said the company has been affected by
<GO> Having a little flexibility on that issue would go
<GO> Long before the advent of e-commerce , Wal-Mart 's
Batch 1 :
the credit crunch in the United States . <GO> Abu
a long way to putting together a final package .
founder Sam Walton set out his vision for a successful
Batch 2 :
Dhabi is going ahead to build solar city and no
<GO> Her back was torn open , her liver was
retail operation : " We let folks know we 're
Batch 3 :
pollution city . <GO> Now it has 175 staging centers
ruptured , one of her lungs had collapsed and the
interested in them and that they 're vital to us--
```
Note that each sentence continues in the next batch. The example text file comes from the 1-billion word benchmark dataset, which has one sentence per line.
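To see where the printed rows come from, the per-row logic of lineToSequence and getFixedLengthSequence can be mirrored in plain Python. This is only a sketch: the snake_case names are hypothetical, it feeds a single row from its own list of lines (in the TF graph all rows share one iterator), and the two sample lines are reassembled from row 0 of the output above.

```python
from typing import Iterator, List, Tuple

GO = "<GO>"


def line_to_sequence(line: str) -> List[str]:
    # Like tf.string_split with the default " " delimiter (empty tokens skipped),
    # followed by a trailing <GO> token that marks the sentence boundary.
    return [tok for tok in line.split(" ") if tok] + [GO]


def get_fixed_length_sequence(start: List[str], lines: Iterator[str],
                              seq_len: int) -> Tuple[List[str], List[str]]:
    # Same condition and body as the tf.while_loop: append lines while too short.
    seq = list(start)
    while len(seq) < seq_len:
        seq += line_to_sequence(next(lines))
    # Clip to seq_len; the rest seeds this row in the next batch.
    return seq[:seq_len], seq[seq_len:]


lines = iter([
    "A spokesman said the company has been affected by the credit crunch in the United States .",
    "Abu Dhabi is going ahead to build solar city and no pollution city .",
])
start = [GO]  # the very first batch starts with <GO>
for b in range(3):
    row, start = get_fixed_length_sequence(start, lines, 10)
    print("Batch", b, ":", " ".join(row))
```

The three printed rows match row 0 of Batches 0 to 2 above, and the final leftover, "pollution city . <GO>", is exactly how row 0 of Batch 3 begins.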