我正在深入研究Tensorflows源代码,以便修复我的翻译模型中的输入形状问题。该模型创建了一个seq2seq模型,该模型再次调用rnn model。在第226行,我的RNNCell被调用,但发送的输入与我单元格中调用方法捕获的输入不同。我错过了什么吗? (注意call_cell()是一个lambda函数)
rnn.py中的 print(_input)
给出以下输出:
Tensor(“encoder0:0”,shape =(8,8),dtype = int32)
而RNNCell中的print(inputs)
给出:
Tensor(“model_with_buckets / embedding_attention_seq2seq / RNN / EmbeddingWrapper / embedding_lookup:0”,shape =(64,1000),dtype = float32,device = / device:CPU:0)
正如您所看到的,输入不一样,但我看不出还有什么。在rnn.py文件中调用call_cell时,我的代码会中断。张量是否受到我看不到的某些操作的影响,或者只是RNNCell收到的另一个张量?