我正在尝试通过在线的可运行示例学习 TensorFlow，但遇到了一个我很想弄清楚其工作原理的示例。谁能解释一下 TensorFlow 这个特定示例背后的数学原理，以及 `ns` 是如何从布尔类型的数据中获得其值的？
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# NOTE: this script uses the TensorFlow 1.x graph/session API
# (tf.InteractiveSession, tf.global_variables_initializer). Under TF 2.x it
# has to be run through tf.compat.v1 with eager execution disabled.

# Sample the rectangle [-5, 5) x [-2.3, 2.3) of the complex plane on a 0.005
# grid; every grid point is one candidate value of c.
Y, X = np.mgrid[-2.3:2.3:0.005, -5:5:0.005]
Z = X + 1j * Y
c = tf.constant(Z, np.complex64)

# zs holds the current orbit value for every c. It starts at c, which is
# P_c(0), the first iterate of z -> z**2 + c started from 0.
zs = tf.Variable(c)
# ns[i, j] accumulates how many iterations point c[i, j] has spent outside
# the escape radius.
ns = tf.Variable(tf.zeros_like(c, tf.float32))

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# One Mandelbrot iteration: z_{n+1} = z_n**2 + c.
zs_ = zs * zs + c
# True where the orbit has left the disk of radius 4. Any orbit that leaves
# the disk of radius 2 is unbounded, so these points are certainly NOT in the
# Mandelbrot set. (This mask used to be called `not_diverged`, which had the
# meaning backwards.)
diverged = tf.abs(zs_) > 4
# Casting the boolean mask to float32 maps True -> 1.0 and False -> 0.0, so
# assign_add increments ns by one exactly where the orbit has diverged.
# Diverged points are frozen at their last value. Iterating them further
# overflows complex64 to inf/nan within a few steps, and since `nan > 4` is
# False they would then silently stop being counted.
step = tf.group(
    zs.assign(tf.where(diverged, zs, zs_)),
    ns.assign_add(tf.cast(diverged, tf.float32)),
)

# Debug summaries, printed after every iteration.
nx = tf.reduce_sum(ns)        # total divergence count over the whole grid
zx = tf.reduce_sum(zs_)       # sum of z**2 + c evaluated at the current zs
cx = tf.reduce_sum(c)         # constant: sum of all c values
zf = tf.reduce_all(diverged)  # True only once every grid point has diverged

for _ in range(200):
    step.run()
    print(sess.run([nx, zx, cx, zf]))

# Fetch the counts before closing the session, then plot the NumPy array.
counts = ns.eval()
sess.close()
plt.imshow(counts)
plt.show()
答案 0（得分：0）
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# This defines the complex plane: every grid point is one candidate c.
Y, X = np.mgrid[-2.3:2.3:0.005, -5:5:0.005]
Z = X + 1j * Y
c = tf.constant(Z, np.complex64)

# Tensors are immutable in TensorFlow, but variables aren't, so use
# variables for the values that are updated on every iteration.
# zs starts at c = P_c(0), the first point of the orbit of 0.
zs = tf.Variable(c)
# ns keeps count, per point, of how many iterations it has spent diverged.
ns = tf.Variable(tf.zeros_like(c, tf.float32))

# NOTE: TensorFlow 1.x graph/session API; under TF 2.x use tf.compat.v1
# with eager execution disabled.
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# The Mandelbrot set M is defined as
#   c \in M  \iff  |P_c^n(0)| <= 2 for all n >= 1,
# where P_c(z) = z^2 + c and P_c^n is its n-fold composition.
# tf.abs returns the modulus |z| (not |z|^2), so the test below uses an
# escape radius of 4 rather than 2. That is still a valid test: an orbit
# with |z| > 4 is in particular outside the disk of radius 2, hence
# unbounded, so such a c is not in M.
# Calling this mask `not_diverged` would be misleading: it is True exactly
# where the orbit HAS diverged, hence the name `diverged`.
zs_ = zs * zs + c
diverged = tf.abs(zs_) > 4

# ns gets its value from the boolean mask cast to float, which maps
# True -> 1. and False -> 0. assign_add adds tf.cast(diverged, tf.float32)
# to the variable ns and stores the sum back into ns. Each point therefore
# gains 1 for every iteration in which it is diverged.
# Diverged points are frozen at their last value, so they stay diverged and
# ns ends up as the number of iterations since c first escaped (larger
# means it escaped sooner). Iterating them further would overflow complex64
# to inf/nan within a few steps, and because `nan > 4` is False they would
# silently stop being counted.
step = tf.group(
    zs.assign(tf.where(diverged, zs, zs_)),
    ns.assign_add(tf.cast(diverged, tf.float32)),
)

# Here we iterate n as far as we like (200 steps). Each step moves one
# point further along the sequence P_c^n(0), which must stay bounded in the
# disk of radius 2 for c to be in M.
for _ in range(200):
    step.run()

# Anywhere with value > 0 in the plot is not in the Mandelbrot set.
# Anywhere with value = 0 MIGHT be in the Mandelbrot set. We don't know for
# sure, because we can only ever take n up to some finite number, but to be
# in the set the sequence has to be bounded for all n!
counts = ns.eval()
sess.close()
plt.imshow(counts)
plt.show()