我正在为 WMT-14 数据集(超过 12,000,000 行)实现一个 NMT 模型。我已把数据转换成适合 TensorFlow 的二进制格式,并使用 tensorflow/models 中的 data_batcher.py 来读取数据。但是当我调用 NextBatch
函数时,程序会一直挂起(阻塞)在下面这行代码上:
buckets = self._bucket_input_queue.get()
相关代码如下:
class Batcher(object):
  """Batch reader with shuffling and bucketing support.

  Examples are parsed from `data_path` by background input threads,
  grouped into bucketed batches by background bucketing threads, and
  handed to the trainer one batch at a time via NextBatch().
  """

  def __init__(self, data_path, vocab, hps,
               article_key, abstract_key, max_article_sentences,
               max_abstract_sentences, bucketing=True, truncate_input=False,
               num_input_threads=16, num_bucketing_threads=4):
    """Batcher constructor.

    Args:
      data_path: tf.Example filepattern.
      vocab: Vocabulary.
      hps: Seq2SeqAttention model hyperparameters.
      article_key: article feature key in tf.Example.
      abstract_key: abstract feature key in tf.Example.
      max_article_sentences: Max number of sentences used from article.
      max_abstract_sentences: Max number of sentences used from abstract.
      bucketing: Whether bucket articles of similar length into the same batch.
      truncate_input: Whether to truncate input that is too long. Alternative is
        to discard such examples.
      num_input_threads: Number of daemon threads filling the input queue.
      num_bucketing_threads: Number of daemon threads assembling batches.
    """
    self._data_path = data_path
    self._vocab = vocab
    self._hps = hps
    self._article_key = article_key
    self._abstract_key = abstract_key
    self._max_article_sentences = max_article_sentences
    self._max_abstract_sentences = max_abstract_sentences
    self._bucketing = bucketing
    self._truncate_input = truncate_input
    # NOTE(review): NextBatch() reads self.unknown_id, but it is never
    # assigned anywhere in this class -- initialize it here (presumably
    # from `vocab`; TODO confirm the vocab API) or NextBatch() raises
    # AttributeError.
    self._input_queue = Queue.Queue(QUEUE_NUM_BATCH * self._hps.batch_size)
    self._bucket_input_queue = Queue.Queue(QUEUE_NUM_BATCH)
    # Daemon reader threads: each is expected to parse examples from
    # data_path into self._input_queue (_FillInputQueue is defined
    # elsewhere). If these threads produce nothing, the bucket queue
    # below stays empty and NextBatch() blocks forever.
    self._input_threads = []
    for _ in range(num_input_threads):
      t = Thread(target=self._FillInputQueue)
      t.daemon = True
      t.start()
      self._input_threads.append(t)
    # Daemon bucketing threads: group queued examples into batches.
    self._bucketing_threads = []
    for _ in range(num_bucketing_threads):
      t = Thread(target=self._FillBucketInputQueue)
      t.daemon = True
      t.start()
      self._bucketing_threads.append(t)
    # Watchdog thread (_WatchThreads is defined elsewhere).
    self._watch_thread = Thread(target=self._WatchThreads)
    self._watch_thread.daemon = True
    self._watch_thread.start()

  def NextBatch(self, drop_prob):
    """Returns a batch of inputs for seq2seq attention model.

    Blocks until a bucketed batch is available on the bucket input queue.

    Args:
      drop_prob: Fraction of each decoder input sequence to replace with
        self.unknown_id (word dropout on the decoder inputs).

    Returns:
      enc_batch: A batch of encoder inputs [batch_size, hps.enc_timesteps].
      dec_batch: A batch of decoder inputs [batch_size, hps.dec_timesteps].
      dec_dropped_batch: dec_batch with about drop_prob of each row's first
        dec_output_len positions replaced by self.unknown_id.
      target_batch: A batch of targets [batch_size, hps.dec_timesteps].
      enc_input_lens: encoder input lengths of the batch.
      dec_output_lens: decoder output lengths of the batch.
      loss_weights: weights for loss function, 1 if not padded, 0 if padded.
      origin_articles: original article words.
      origin_abstracts: original abstract words.
    """
    hps = self._hps
    enc_batch = np.zeros(
        (hps.batch_size, hps.enc_timesteps), dtype=np.int32)
    enc_input_lens = np.zeros(
        (hps.batch_size), dtype=np.int32)
    dec_batch = np.zeros(
        (hps.batch_size, hps.dec_timesteps), dtype=np.int32)
    dec_dropped_batch = np.zeros(
        (hps.batch_size, hps.dec_timesteps), dtype=np.int32)
    dec_output_lens = np.zeros(
        (hps.batch_size), dtype=np.int32)
    target_batch = np.zeros(
        (hps.batch_size, hps.dec_timesteps), dtype=np.int32)
    loss_weights = np.zeros(
        (hps.batch_size, hps.dec_timesteps), dtype=np.float32)
    origin_articles = ['None'] * hps.batch_size
    origin_abstracts = ['None'] * hps.batch_size
    # Blocks forever if no bucketing thread ever produces a batch; see
    # _FillBucketInputQueue for why that can happen.
    buckets = self._bucket_input_queue.get()
    for i in range(hps.batch_size):
      (enc_inputs, dec_inputs, targets, enc_input_len, dec_output_len,
       article, abstract) = buckets[i]
      # NOTE(review): np.random.choice samples WITH replacement here, so
      # fewer than int(drop_prob*dec_output_len) distinct positions may
      # be dropped; pass replace=False if exact-count dropout is intended.
      drop_idx = np.random.choice(dec_output_len,
                                  int(drop_prob * dec_output_len))
      origin_articles[i] = article
      origin_abstracts[i] = abstract
      enc_input_lens[i] = enc_input_len
      dec_output_lens[i] = dec_output_len
      enc_batch[i, :] = enc_inputs[:]
      dec_batch[i, :] = dec_inputs[:]
      dec_dropped_batch[i, :] = dec_inputs[:]
      dec_dropped_batch[i, drop_idx] = self.unknown_id
      target_batch[i, :] = targets[:]
      # Unpadded positions get weight 1; padding stays 0. Vectorized
      # slice assignment replaces the original per-position Python loop.
      loss_weights[i, :dec_output_len] = 1.0
    return (enc_batch, dec_batch, dec_dropped_batch, target_batch,
            enc_input_lens, dec_output_lens, loss_weights,
            origin_articles, origin_abstracts)

  def _FillBucketInputQueue(self):
    """Fill bucketed batches into the bucket_input_queue.

    Note: nothing is put on the bucket queue until a full cache of
    batch_size * BUCKET_CACHE_BATCH examples has been read from the input
    queue. So if the input threads produce no examples (bad filepattern,
    parse failures, or dead _FillInputQueue threads), the inner `get`
    below blocks forever and NextBatch() in turn hangs on an empty bucket
    queue -- check the input side first when NextBatch() never returns.
    """
    while True:
      # Accumulate a cache of examples; blocks when the input queue is empty.
      inputs = []
      for _ in range(self._hps.batch_size * BUCKET_CACHE_BATCH):
        inputs.append(self._input_queue.get())
      if self._bucketing:
        # Sort by encoder length so each batch holds similar-length inputs,
        # minimizing padding within a batch.
        inputs = sorted(inputs, key=lambda inp: inp.enc_len)
      batches = []
      for i in range(0, len(inputs), self._hps.batch_size):
        batches.append(inputs[i:i + self._hps.batch_size])
      # Shuffle batch order so bucketing does not impose a length curriculum.
      shuffle(batches)
      for b in batches:
        self._bucket_input_queue.put(b)
我想知道为什么 _bucket_input_queue
中一直取不到元素(get 调用永久阻塞)?我该如何排查和调整呢?先在此谢过大家!