Training with TFRecords gradually slows down

Time: 2016-12-02 19:53:29

Tags: python tensorflow deep-learning tensorflow-serving

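I am trying to train a network in TensorFlow using TFRecord files. The problem is that training starts out fine, but after a while it becomes very slow. At times even the GPU utilization drops to 0%. I have measured the time between iterations, and it is clearly increasing. I have read that this can happen when operations are added to the graph inside the training loop, and that graph.finalize() can be used to address it.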
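One way to check whether the graph really is growing is to count its operations before and after a few iterations. The following is only a minimal sketch, not my actual model: the variable `step` and the five-iteration loop are stand-ins, and it assumes a TensorFlow build that ships the `tensorflow.compat.v1` module (the code in this question uses the older 0.x API).

```python
import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
    # Stand-in for a real model: a single counter variable.
    step = tf.Variable(0, trainable=False, name="global_step")
    init = tf.global_variables_initializer()
    ops_before = len(graph.get_operations())

    with tf.Session(graph=graph) as sess:
        sess.run(init)
        for it in range(5):
            # Anti-pattern: every call to assign() builds new ops in the graph.
            sess.run(step.assign(it))

    ops_after = len(graph.get_operations())

ops_added = ops_after - ops_before
print("ops added during the loop:", ops_added)
```

If the count keeps rising while the loop runs, some call inside the loop is building new ops, which would be consistent with iterations getting slower over time.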

My code looks like this:

    self.inputMR_,self.CT_GT_ = read_and_decode_single_example("data.tfrecords")

    self.inputMR, self.CT_GT = tf.train.shuffle_batch([self.inputMR_, self.CT_GT_], batch_size=self.batch_size, num_threads=2,
        capacity=500*self.batch_size,min_after_dequeue=2000)

    batch_size_tf = tf.shape(self.inputMR)[0]  #variable batchsize so we can test here
    self.train_phase = tf.placeholder(tf.bool, name='phase_train')
    self.G = self.Network(self.inputMR,batch_size_tf)# create the network
    self.g_loss=lp_loss(self.G, self.CT_GT, self.l_num, batch_size_tf)
    print 'learning rate ',self.learning_rate
    self.g_optim = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.g_loss)
    self.saver = tf.train.Saver()

Then I have a training phase that looks like this:

    def train(self, config):
        init=tf.initialize_all_variables()
        with tf.Session() as sess:
            sess.run(init)

            coord = tf.train.Coordinator()
            threads=tf.train.start_queue_runners(sess=sess, coord=coord)
            sess.graph.finalize()# **WHERE SHOULD I PUT THIS?**

            try:

                while not coord.should_stop():
                    _,loss_eval = sess.run([self.g_optim, self.g_loss],feed_dict={self.train_phase: True})
                    .....

            except:
                e = sys.exc_info()[0]

                print "Exception !!!", e
            finally:
                coord.request_stop()

            coord.join(threads)
            sess.close()

When I add graph.finalize(), I get an exception: type 'exceptions.RuntimeError'. Can anyone explain the correct way to use TFRecord files during training, and how to use graph.finalize() without interfering with the QueueRunner execution?

The full error is:

  File "main.py", line 37, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 30, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "main.py", line 35, in main
    gen_model.train(FLAGS)
  File "/home/dongnie/Desktop/gan/TF_record_MR_CT/model.py", line 143, in train
    self.global_step.assign(it).eval() # set and update(eval) global_step with index, i
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variables.py", line 505, in assign
    return state_ops.assign(self._variable, value, use_locking=use_locking)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 45, in assign
    use_locking=use_locking, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 490, in apply_op
    preferred_dtype=default_dtype)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 657, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/constant_op.py", line 180, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/constant_op.py", line 167, in constant
    attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2337, in create_op
    self._check_not_finalized()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2078, in _check_not_finalized
    raise RuntimeError("Graph is finalized and cannot be modified.")
RuntimeError: Graph is finalized and cannot be modified.
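
The traceback points at line 143 of model.py, where `self.global_step.assign(it)` is called during training, so the op being rejected is an assign on the global step rather than anything in the queue runners. A possible workaround, sketched below under assumptions, is to build that assign op once with a placeholder before finalizing the graph and only run it inside the loop. The names `step_value` and `set_step` are hypothetical, the loop is a stand-in for the real training loop, and the sketch assumes a TensorFlow build that provides `tensorflow.compat.v1`.

```python
import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
    global_step = tf.Variable(0, trainable=False, name="global_step")
    # Build the assign op once; the new value is fed at run time.
    step_value = tf.placeholder(tf.int32, shape=[], name="step_value")
    set_step = global_step.assign(step_value)
    init = tf.global_variables_initializer()

# Every op the loop needs now exists, so the graph can be frozen.
graph.finalize()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    for it in range(1, 4):
        # Only runs existing ops; nothing is added to the graph here.
        sess.run(set_step, feed_dict={step_value: it})
    final_step = sess.run(global_step)

# Creating an op after finalize() fails, as in the traceback above.
try:
    with graph.as_default():
        global_step.assign(99)
    raised = False
except RuntimeError:
    raised = True

print("final step:", final_step, "| op creation rejected:", raised)
```

With this layout, finalize() is called only after every op has been created, including the initializer, and any later attempt to build an op raises the same RuntimeError instead of silently growing the graph.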

0 answers:

No answers