为什么使用eval和session.run这么慢?

时间:2018-09-29 21:42:57

标签: python tensorflow machine-learning

我正在Tensorflow中构建自定义优化器。这是基于梯度的优化器。我下面有训练虹膜花分类器的工作代码。但是,它运行非常慢。需要45秒才能完成。当我使用Tensorflow随附的优化器时,需要13秒。

我不确定为什么会这么慢或如何加快速度。我相信根本原因是session.run调用。有谁知道为什么这段代码运行这么慢?我可以采取哪些步骤来提高速度?

main.py

    optimizer = GradientDescentOptimizer(0.5, loss, sess)
    while True:
        step += 1
        optimizer.minimize(session=sess, feed_dict={x:X_train, y:y_train})

gradient_descent.py

@tf_export("train.GradientDescentOptimizer")
class GradientDescentOptimizer(Optimizer):

  ...
  def __init__(self, loss, session)
    self.cg = self.compute_gradients(loss, tf.trainable_variables())
    self.ag = self.apply_gradients(self.cg)  

    self.reset_gradients(session)


  def reset_gradients(self, sess):                                                                                                                                              
     self.gradients = [tf.zeros(g[1].get_shape()).eval(session=sess) for g in self.cg]  

  def minimize(self, session=None, feed_dict=None):
    feed = feed_dict
    for i in range(len(self.cg)):
      self.gradients[i] = session.run(self.cg[i][0], feed_dict=feed)        

    feed_dict_grads = {}                                                                                                                                                    
    for i, grad_var in enumerate(self.cg):                                                                                                             
      feed_dict_grads[grad_var[0]] = self.gradients[i]

    session.run(self.ag, feed_dict=feed_dict_grads)

0 个答案:

没有答案