Make a subgraph compute only once in TensorFlow

Posted: 2016-11-22 04:42:35

Tags: python tensorflow

I am building a deep learning model with TensorFlow. Before training, I do some computation (something like back-propagation), but it only needs to be computed once. Here is my pseudocode:

class residual_net():
    def pseudo_bp(self):
        # do something that only needs to happen once...
        self.bp = ...

    def build_net(self):
        # build a residual network...
        # uses the value prepared in pseudo_bp
        self.output = func(self.bp)

def run():
    rn = residual_net()
    rn.pseudo_bp()
    rn.build_net()
    sess = tf.InteractiveSession()
    sess.run(tf.initialize_all_variables())
    for i in range(1000):
        err = tf.reduce_mean(tf.square(rn.output - labels))  # e.g. a mean-squared error against the labels
        train = tf.train.GradientDescentOptimizer(learning_rate).minimize(err)
        sess.run(train, feed_dict=train_feed_dict)

I would like to know: will pseudo_bp run in every iteration? If so, how can I make it run only once? Thanks in advance!

Edit: here is the latest error:

Traceback (most recent call last):
  File "run.py", line 124, in <module>
    sess.run(pseudo_bp, feed_dict=feed_dict)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 717, in run
    run_metadata_ptr)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 902, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 358, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 178, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

Do you have any ideas?

1 answer:

Answer 0 (score: 0):

In TensorFlow you first build a tf.Graph. The graph consists of variables, operations, and placeholders. You then start a tf.Session(), inside which you can execute operations and update variables.

In this case, I think pseudo_bp ultimately has to compute some operations (such as tf.matmul). sess works like a handle: whenever you call sess.run(op) it executes the corresponding tf.Operation, and you supply inputs for the placeholders through feed_dict.
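
As a minimal sketch of this pattern (using the same TensorFlow 1.x-style API as the rest of this answer; the placeholder shape and values below are purely illustrative, not taken from the question): building an op only adds a node to the graph, and the op is evaluated exactly as often as you fetch it with sess.run.

import tensorflow as tf

# Graph construction: these lines only add nodes to the default graph.
x = tf.placeholder(tf.float32, shape=[None, 3])   # filled through feed_dict at run time
w = tf.Variable(tf.zeros([3, 1]))                 # state that persists across sess.run calls
y = tf.matmul(x, w)                               # an operation node; nothing is computed yet

sess = tf.Session()
sess.run(tf.initialize_all_variables())

# `y` is evaluated only when it is fetched; each sess.run executes it once.
result = sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]})
# If `y` is never fetched again, it is never recomputed.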

So you should execute sess.run(op) only in the first iteration of the for loop. Here is the resulting code:

class residual_net():
    def pseudo_bp(self):
        # do something...
        return op

    def build_net(self):
        # build a residual network...
        self.output = sth

def run():
    rn = residual_net()
    operation = rn.pseudo_bp()
    rn.build_net()
    err = tf.reduce_mean(tf.square(rn.output - labels))  # e.g. a mean-squared error against the labels
    train = tf.train.GradientDescentOptimizer(learning_rate).minimize(err)
    # The graph is now fully built. Begin the tf.Session()
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    for i in range(1000):
        # Carry out a training step in each iteration
        # (note that `train` is an operation here)
        sess.run(train, feed_dict=feed_dict)
        if i == 0:
            # Execute `operation` only in the first iteration
            result = sess.run(operation, feed_dict=feed_dict)
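
A possible simplification, under the same assumptions as the sketch above: if the result of `operation` is really needed before training starts (as the question suggests), you could also run it once before entering the loop inside run(), instead of guarding it with `if i == 0`, for example:

    # computed exactly once, before any training step
    result = sess.run(operation, feed_dict=feed_dict)
    for i in range(1000):
        sess.run(train, feed_dict=feed_dict)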