我试图找到最有效的张量流方式,以对大型张量(> 200行和列)执行依赖于路径的更新。
解决方案必须具有差异性(并可能与xla兼容)
我目前正在使用tf.unstack,检查for循环中的每个张量,并使用tf.where过滤出所需的条件。这很慢,并且导致许多张量运算
Bt = tf.ones([256])
Bt_n = tf.random_normal([200,256]) # would actually be calculated elsewhere
Mr = tf.random_normal([200,256])
Mp = tf.random_normal([200,256])
total = [Bt]
for mr, mp, n_Bt in zip(tf.unstack(Mr),
tf.unstack(Mp),
tf.unstack(Bt_n)):
Bt = tf.where(tf.logical_or(Bt <= mr, Bt >= mp), n_Bt, Bt)
total.append(Bt)
final = tf.concat(total, axis=0)
只需寻找最有效的方法(需要进行最少操作)即可。
谢谢。
答案 0 :(得分:0)
找到了答案-我需要使用tf.scan
即。
tf.scan(lambda a, x: tf.where(tf.logical_or(a <= x[0], a >= x[1]), x[2], a) , [Mr,Mp,Bt_n], initializer = Bt)