如何在Tensorflow中加快嵌套for循环?

时间:2019-06-11 02:10:24

标签: python python-3.x tensorflow

当通过所有可能的组合连接张量时,我遇到一个问题。

假设我有三个块张量Z_priorZ_currentZ_post,尺寸分别为(n1, l)(n2, l)(n3, l)。我想使用

Z_combined = tf.concat([Z_prior, Z_current, Z_post], axis=1)

尺寸为(n1+n2+n3, l),以执行进一步的计算。

困难之处在于,Z_priorZ_currentZ_post中的每一个分别具有c1c2c3选择,在达到最终结果之前,我应该将c1*c2*c3的输出与不同的输入Z_combined进行比较。其他所有输入都是固定的。

目前,我只能想到的方法如下: 首先,我创建了三个列表对象,用于保存所有可能的组合。

Z_priors = [Z_prior_1, ..., Z_prior_c1]
Z_currents = [Z_current_1, ..., Z_current_c2]
Z_posts = [Z_post_1, ... , Z_post_c3]

然后,我使用嵌套的for循环来计算每个特定串联的结果,并将它们保存在一起。

results = tf.zeros([1], tf.float32)
for prior_index in range(c1):
    for current_index in range(c2):
        for post_index in range(c3):
            Z_combined = tf.concat([Z_priors[prior_index],
                                    Z_currents[current_index],
                                    Z_posts[post_index]], axis=1)
            result = fn(Z_combined, ...)
            results = tf.concat([results, [result]], axis=0)
results = results[1:]

(请注意,这里的结果是标量,所以我的结果应该是尺寸为(c1*c2*c3,)的1维数组。

但是一旦我用sess.run()对其进行了评估,那就太慢了。

我想知道在这种情况下是否有任何方法可以加快计算速度?

谢谢。

0 个答案:

没有答案