tensorflow网站上有an example,其中TF用于计算Mandlebrot集。这是相关的代码段:
# Compute the new values of z: z^2 + x
zs_ = zs*zs + xs
# Have we diverged with this new value?
not_diverged = tf.abs(zs_) < 4
# Operation to update the zs and the iteration count.
#
# Note: We keep computing zs after they diverge! This
# is very wasteful! There are better, if a little
# less simple, ways to do this.
#
step = tf.group(
zs.assign(zs_),
ns.assign_add(tf.cast(not_diverged, tf.float32))
)
相关的引言是:“有更好的方法,即使不太简单,也可以做到这一点。”谁知道有什么更好的方法吗?
我正在搞乱TF中的光线跟踪,并且遇到90%像素会聚的情况,但是我一直在重新计算它们,因为我不知道如何更新张量中的全部子集而不会牺牲在任何地方使用向量运算的速度优势。
答案 0 :(得分:0)
我想出了一种方法。
关键是使用tf.where
查找尚未发散的像素,tf.gather_nd
将这些像素拉入较小的阵列,对这些特定像素执行更新步骤,然后使用{ {1}}(或另一个“分散”函数)将稀疏更新应用于表示状态的某些变量。
在我的用例中,这节省了大量时间。