标签: tensorflow memory machine-translation
我正在开发一个多语言机器翻译设置,其中包含多个编码器/解码器(每种语言都有一个),在火车时刻,我将批量从单个源提供给单个目标,然后在编码器之间切换和tf.case的解码器,具体取决于批次中给出的lang_src和lang_tgt。 我的问题是当我使用5种语言时,我达到了GPU的12GB内存限制 我不确定tensorflow是如何工作的,但我认为它为每个分支的激活或渐变分配GPU内存,但这在我的情况下是不必要的,因为对于任何批处理只有一条路径。 有没有办法优化它?
tf.case
lang_src
lang_tgt