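I am building a deep learning model with TensorFlow. Before training, I do some computation, such as backpropagation, but it only needs to be computed once. Here is my pseudocode: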
import tensorflow as tf

class residual_net(object):
    def pseudo_bp(self):
        # do something...
        self.bp = ...

    def build_net(self):
        # build a residual_network....
        # utilize the variable in pseudo_bp
        self.output = func(self.bp)

def run():
    rn = residual_net()
    rn.pseudo_bp()
    rn.build_net()
    sess = tf.InteractiveSession()
    sess.run(tf.initialize_all_variables())
    for i in range(1000):
        err = tf.reduce_mean(rn.output, labels)
        train = tf.train.GradientDescentOptimizer(learning_rate).minimize(err)
        sess.run(train, feed_dict=train_feed_dict)
I'd like to know: does pseudo_bp run in every iteration? If it does, how can I make it run only once? Thanks in advance!
Edit: the latest error:
Traceback (most recent call last):
  File "run.py", line 124, in <module>
    sess.run(pseudo_bp, feed_dict=feed_dict)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 717, in run
    run_metadata_ptr)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 902, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 358, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 178, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>
Any ideas?
Answer 0 (score: 0)
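In TensorFlow, you first build a tf.Graph. The graph consists of variables, operations, and placeholders. You then start a tf.Session(), in which you execute operations and update variables.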
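To make the build/run split concrete, here is a minimal sketch. It assumes a TensorFlow release that ships the TF1-style `tf.compat.v1` API (late 1.x or 2.x), and `pseudo_bp` here is a hypothetical stand-in for the asker's method: its Python body runs once, when the graph is built, while the op it returns executes every time it is passed to `sess.run`.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API

build_calls = []  # records how many times the Python builder runs


def pseudo_bp(counter):
    """Hypothetical builder: runs Python code once and returns a graph op."""
    build_calls.append(1)
    # assign_add only *describes* work; nothing happens until sess.run
    return counter.assign_add(1.0)


graph = tf.Graph()
with graph.as_default():
    counter = tf1.Variable(0.0, name="counter")
    op = pseudo_bp(counter)  # the Python method executes here, exactly once
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    for _ in range(3):
        sess.run(op)  # the graph op executes on every sess.run call
    final_count = sess.run(counter)

print(len(build_calls), final_count)
```

After the loop, `build_calls` has a single entry while `counter` has been incremented three times.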
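In this case, I think pseudo_bp ultimately needs to compute some operation (such as tf.matmul). Calling pseudo_bp in Python only adds that operation to the graph; sess acts like a handle to the graph, and whenever you call sess.run(op), it executes the tf.Operations that op depends on. You supply the inputs for the placeholders through feed_dict.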
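A minimal sketch of placeholders and `feed_dict`, under the same assumption (`tf.compat.v1` is available): the same op returns different results depending on what is fed, and fetching `None` reproduces the `TypeError` from the question's edit (the exact message text differs between TensorFlow versions).

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API

graph = tf.Graph()
with graph.as_default():
    # A placeholder holds no value of its own; sess.run fills it from feed_dict
    x = tf1.placeholder(tf.float32, shape=[None], name="x")
    doubled = 2.0 * x  # an op that depends on the placeholder

with tf1.Session(graph=graph) as sess:
    out_a = sess.run(doubled, feed_dict={x: np.array([1.0, 2.0], np.float32)})
    out_b = sess.run(doubled, feed_dict={x: np.array([5.0], np.float32)})

    # Fetching None (e.g. the result of a method with no return statement)
    # raises TypeError, as in the question's traceback
    none_error_raised = False
    try:
        sess.run(None)
    except TypeError:
        none_error_raised = True

print(out_a, out_b, none_error_raised)
```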
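So you only need to call sess.run(op) in the first iteration of the for loop. Note also that pseudo_bp must return the op it builds: the TypeError in your edit (Fetch argument None) means the value passed to sess.run was None, most likely because pseudo_bp in run.py holds the return value of a method without a return statement. Here is the resulting code: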
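import tensorflow as tf

class residual_net(object):
    def pseudo_bp(self):
        # do something...
        op = ...  # build the one-off op(s) here
        # Return the op so sess.run() can fetch it (a missing return gives None)
        return op

    def build_net(self):
        # build a residual_network....
        self.output = ...  # e.g. func(...) of the pseudo_bp result

def run():
    rn = residual_net()
    operation = rn.pseudo_bp()
    rn.build_net()
    # labels, learning_rate and feed_dict are assumed to be defined elsewhere.
    # Assumed loss: tf.reduce_mean's 2nd argument is an axis, not the labels.
    err = tf.reduce_mean(tf.square(rn.output - labels))
    train = tf.train.GradientDescentOptimizer(learning_rate).minimize(err)
    # Graph has been built completely. Begin tf.Session()
    sess = tf.Session()
    # (tf.global_variables_initializer() in newer TF 1.x releases)
    sess.run(tf.initialize_all_variables())
    for i in range(1000):
        # Carry out the training in each iteration
        # Note that train is an operation here
        sess.run(train, feed_dict=feed_dict)
        if i == 0:
            # Execute `operation` only in the first iteration
            result = sess.run(operation, feed_dict=feed_dict)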
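One caveat: if rn.output is built directly from the tensor that pseudo_bp returns, every sess.run(train) re-executes that computation, because train depends on it; skipping the explicit fetch in later iterations does not prevent this. A common way to really compute it once is to store the result in a non-trainable variable and build the network on that variable. The sketch below assumes `tf.compat.v1`; `expensive_bp`, the shapes, and the squared-error loss are hypothetical, and `tf1.py_func` is used only so a Python counter can show how often the computation actually runs.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API

bp_calls = []  # counts how often the one-off computation actually executes


def expensive_bp():
    # Hypothetical stand-in for the work done in pseudo_bp
    bp_calls.append(1)
    return np.full((3, 1), 0.5, dtype=np.float32)


graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 3], name="x")
    labels = tf1.placeholder(tf.float32, shape=[None, 1], name="labels")

    bp_expr = tf1.py_func(expensive_bp, [], tf.float32)
    bp_expr.set_shape([3, 1])
    # Non-trainable cache: the training graph reads this, never bp_expr
    bp_cache = tf1.Variable(tf.zeros([3, 1]), trainable=False, name="bp_cache")
    compute_bp_once = bp_cache.assign(bp_expr)

    w = tf1.Variable(tf.zeros([3, 1]), name="w")
    output = tf.matmul(x, w + bp_cache)  # plays the role of func(self.bp)
    err = tf.reduce_mean(tf.square(output - labels))  # assumed loss
    train = tf1.train.GradientDescentOptimizer(0.1).minimize(err)
    init = tf1.global_variables_initializer()

feed = {x: np.ones((4, 3), np.float32), labels: np.ones((4, 1), np.float32)}
with tf1.Session(graph=graph) as sess:
    sess.run(init)
    sess.run(compute_bp_once)  # the only time expensive_bp runs
    for _ in range(5):
        sess.run(train, feed_dict=feed)  # reuses the cached value
    cached = sess.run(bp_cache)

print(len(bp_calls), cached.ravel())
```

Because train only reads bp_cache, the five training steps leave bp_calls with a single entry. An alternative with the same effect is to fetch the result once with sess.run and feed it back through a placeholder on every step.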