当通过所有可能的组合连接张量时,我遇到一个问题。
假设我有三个块张量Z_prior
,Z_current
和Z_post
,尺寸分别为(n1, l)
,(n2, l)
和(n3, l)
。我想使用
Z_combined = tf.concat([Z_prior, Z_current, Z_post], axis=1)
尺寸为(n1+n2+n3, l)
,以执行进一步的计算。
困难之处在于,Z_prior
,Z_current
和Z_post
中的每一个分别具有c1
,c2
和c3
选择,在达到最终结果之前,我应该将c1*c2*c3
的输出与不同的输入Z_combined
进行比较。其他所有输入都是固定的。
目前,我只能想到的方法如下: 首先,我创建了三个列表对象,用于保存所有可能的组合。
Z_priors = [Z_prior_1, ..., Z_prior_c1]
Z_currents = [Z_current_1, ..., Z_current_c2]
Z_posts = [Z_post_1, ... , Z_post_c3]
然后,我使用嵌套的for循环来计算每个特定串联的结果,并将它们保存在一起。
results = tf.zeros([1], tf.float32)
for prior_index in range(c1):
for current_index in range(c2):
for post_index in range(c3):
Z_combined = tf.concat([Z_priors[prior_index],
Z_currents[current_index],
Z_posts[post_index]], axis=1)
result = fn(Z_combined, ...)
results = tf.concat([results, [result]], axis=0)
results = results[1:]
(请注意,这里的结果是标量,所以我的结果应该是尺寸为(c1*c2*c3,)
的1维数组。
但是一旦我用sess.run()
对其进行了评估,那就太慢了。
我想知道在这种情况下是否有任何方法可以加快计算速度?
谢谢。