TF map_fn或while_loop用于不同形状的张量列表

时间:2018-04-08 05:47:17

标签: python tensorflow control-flow

我想处理不同形状的张量序列(列表)并输出另一张张量列表。在每个时间戳上考虑具有不同隐藏状态大小的RNN。像

这样的东西

输入:[tf.ones((1,2,2)),tf.ones((2,2,3)),tf.ones((3,2,1))]

输出:[tf.zeros((1,2,4)),tf.zeros((4,2,6)),tf.zeros((6,2,1))]

我无法将输入(或输出)堆叠到单个张量中,因为它们都具有不同的形状,因此我不能将tf.map_fn用于任务。现在,我使用python for循环,但它似乎不是最理想的。

我能做些什么更好的事情吗?

1 个答案:

答案 0 :(得分:2)

您可以使用tf.while_loop重复执行任意TensorFlow操作,直到出现某些停止条件。停止条件本身被指定为操作。

请注意,tf.while_loop应谨慎使用,因为默认情况下其迭代将并行运行。例如,如果循环体增加tf.Variable,则必须使用control dependencies来确保迭代按顺序运行。

但是,你提到你有一个Python循环的工作实现。如果可能,使用Python进行循环通常是最有效的解决方案。在Python中构建循环时,可以为循环中的每次迭代创建单独的ops。这让TensorFlow决定在图形构建中如何为每个操作分配计算资源。例如,如果事先知道迭代次数,则更容易预测内存需求和并行化可能性。

出于这个原因,tf.while_looptf.map_fn是最常用的,因为在图形构建时不知道停止条件。

如果存在固定但非常大量的迭代,您可能仍然希望使用tf.while_loop而不是Python循环,因为每个操作都有一个非常重要的内存成本。