以seq2seq教程为例,假设我们有[[5,5],(10,10)]和批量大小为16的桶: model_with_buckets用于构建模型。对于model_with_buckets(其为encoder_inputs)的输入,它是来自一个桶的批次,例如。大小是5 * 16 但是,即使它的大小与存储桶大小不同,也有代码将该批次运行到所有存储桶
# tensorflow/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
# in def model_with_buckets()
# this will run twice: seq2seq(encoder_inputs[:5],...) and seq2seq(encoder_inputs[:10],...)
# but encoder_inputs only belongs to bucket (5,5) and with size 5*16
for j, bucket in enumerate(buckets):
...
bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
decoder_inputs[:bucket[1]])
输出时,只会使用encoder_inputs所属的存储桶丢失。
# models/tutorials/rnn/translate/seq2seq_model.py
# in def step()
output_feed = [self.updates[bucket_id], # Update Op that does SGD.
self.gradient_norms[bucket_id], # Gradient norm.
self.losses[bucket_id]] # Loss for this batch.
所以在我看来,model_with_buckets正在做不必要的工作,将encoder_input提供给它不属于的其他存储桶。这样做的目的是什么?
答案 0 :(得分:0)
代码处于图形组合阶段。这就是重点。 与会话链接时,step()函数将选择特定的存储桶。