我正在关注如何使用tf.scan
this tutorial,我写了一个最小的工作示例(请参阅下面的代码)。但是每次调用函数Model._step()
时,是不是都在创建计算图的另一个副本?如果没有,为什么不呢?
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # to avoid TF suggesting SSE4.2, AVX etc...
class Model():
def __init__(self):
self._inputs = tf.placeholder(shape=[None], dtype=tf.float32)
self._predictions = self._compute_predictions()
def _step(self, old_state, new_input):
# ---- In here I will write a much more complex graph ----
return old_state + new_input
def _compute_predictions(self):
return tf.scan(self._step, self._inputs, initializer = tf.Variable(0.0))
@property
def predictions(self):
return self._predictions
@property
def inputs(self):
return self._inputs
def test(sess, model):
sess.run(tf.global_variables_initializer())
print(sess.run(model.predictions, {model.inputs: [1.0, 2.0, 3.0, 4.0]}))
test(tf.Session(), Model())
我问,因为这当然是一个最小的例子,在我的情况下,我需要一个更复杂的图形。
答案 0 :(得分:1)
Model._step()
方法只会在每个构造的Model
对象中调用一次。 tf.scan()
函数与它包装的tf.while_loop()
函数一样,只调用它们给定的函数一次来构建一个带有循环的图形,然后每次迭代都会使用相同的图形循环。
(请注意,如果您构建了许多Model
个对象,那么您最终会得到与您拥有Model
个对象相同数量的图表副本。)