大张量的有效路径更新

时间:2019-06-11 05:30:52

标签: python tensorflow

我试图找到最有效的张量流方式,以对大型张量(> 200行和列)执行依赖于路径的更新。

解决方案必须具有差异性(并可能与xla兼容)

我目前正在使用tf.unstack,检查for循环中的每个张量,并使用tf.where过滤出所需的条件。这很慢,并且导致许多张量运算


Bt = tf.ones([256])
Bt_n = tf.random_normal([200,256]) # would actually be calculated elsewhere
Mr = tf.random_normal([200,256])
Mp = tf.random_normal([200,256])

total = [Bt]

for mr, mp, n_Bt in zip(tf.unstack(Mr), 
                      tf.unstack(Mp),                                                      
                      tf.unstack(Bt_n)):
    Bt = tf.where(tf.logical_or(Bt <= mr, Bt >= mp), n_Bt, Bt)
    total.append(Bt)

final = tf.concat(total, axis=0)

只需寻找最有效的方法(需要进行最少操作)即可。

谢谢。

1 个答案:

答案 0 :(得分:0)

找到了答案-我需要使用tf.scan

即。

tf.scan(lambda a, x: tf.where(tf.logical_or(a <= x[0], a >= x[1]), x[2], a) , [Mr,Mp,Bt_n], initializer = Bt)