预计算并存储Tensorflow图以获取损失函数

时间:2018-07-23 00:21:45

标签: python performance tensorflow rnn loss-function

我正在研究TensorFlow程序的序列,该程序使用RNN预测大小为(BATCH_SIZE x L x 11)的张量。我有一个损失函数,该函数采用(BATCH_SIZE x L x 11)张量并将其转换为表示3D坐标的张量,其大小为(BATCH_SIZE x L' x 3),其中L'L的函数。

不幸的是,当我实现此损失函数以对后一个张量(L')进行操作时,我的训练时间从9小时/纪元猛增到300小时/纪元。我相信问题出在每一批的每个计算图的创建上-这使一切变得非常缓慢。

是否可以通过某种方式预先计算部分计算图以节省时间?建筑和差异化似乎都将永远消失。 (请注意,图的这一部分仅存在于损失函数中,无需进行训练,因为我具有从L空间到L'空间的数学映射。)

非常感谢。对于如何在TF中加载预构建的子图或是否有更好的选择(例如,移至支持动态批处理形状的PyTorch这样的后端),我将不胜感激。

0 个答案:

没有答案