seq2seq.py#model_with_buckets中不必要的操作

时间:2017-03-30 07:06:15

标签: tensorflow

以seq2seq教程为例,假设我们有[[5,5],(10,10)]和批量大小为16的桶: model_with_buckets用于构建模型。对于model_with_buckets(其为encoder_inputs)的输入,它是来自一个桶的批次,例如。大小是5 * 16 但是,即使它的大小与存储桶大小不同,也有代码将该批次运行到所有存储桶

# tensorflow/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
# in def model_with_buckets()
# this will run twice: seq2seq(encoder_inputs[:5],...) and seq2seq(encoder_inputs[:10],...)
# but encoder_inputs only belongs to bucket (5,5) and with size 5*16
for j, bucket in enumerate(buckets):
    ...
    bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
                                decoder_inputs[:bucket[1]])

输出时,只会使用encoder_inputs所属的存储桶丢失。

# models/tutorials/rnn/translate/seq2seq_model.py
# in def step()
output_feed = [self.updates[bucket_id],  # Update Op that does SGD.
               self.gradient_norms[bucket_id],  # Gradient norm.
               self.losses[bucket_id]]  # Loss for this batch.

所以在我看来,model_with_buckets正在做不必要的工作,将encoder_input提供给它不属于的其他存储桶。这样做的目的是什么?

1 个答案:

答案 0 :(得分:0)

代码处于图形组合阶段。这就是重点。 与会话链接时,step()函数将选择特定的存储桶。