我试图用map_fn替换for循环,因为后者似乎有助于提高循环效率。
问题是,如果map_fn中的fn调用get_variable()来创建一个新变量,那么如何在循环的其余部分将重用设置为True?或者get_variable()仅在map_fn中调用一次?
def fn(x):
y = tf.get_variable('y', [])
return x * x
squares = tf.map_fn(fn, np.array([1, 2, 3, 4 ,5 ,6]))
# Out: [array([ 1, 4, 9, 16, 25, 36])]
sess.run([squares])