我正在Tensorflow中构建自定义优化器。这是基于梯度的优化器。我下面有训练虹膜花分类器的工作代码。但是,它运行非常慢。需要45秒才能完成。当我使用Tensorflow随附的优化器时,需要13秒。
我不确定为什么会这么慢或如何加快速度。我相信根本原因是session.run
调用。有谁知道为什么这段代码运行这么慢?我可以采取哪些步骤来提高速度?
main.py
optimizer = GradientDescentOptimizer(0.5, loss, sess)
while True:
step += 1
optimizer.minimize(session=sess, feed_dict={x:X_train, y:y_train})
gradient_descent.py
@tf_export("train.GradientDescentOptimizer")
class GradientDescentOptimizer(Optimizer):
...
def __init__(self, loss, session)
self.cg = self.compute_gradients(loss, tf.trainable_variables())
self.ag = self.apply_gradients(self.cg)
self.reset_gradients(session)
def reset_gradients(self, sess):
self.gradients = [tf.zeros(g[1].get_shape()).eval(session=sess) for g in self.cg]
def minimize(self, session=None, feed_dict=None):
feed = feed_dict
for i in range(len(self.cg)):
self.gradients[i] = session.run(self.cg[i][0], feed_dict=feed)
feed_dict_grads = {}
for i, grad_var in enumerate(self.cg):
feed_dict_grads[grad_var[0]] = self.gradients[i]
session.run(self.ag, feed_dict=feed_dict_grads)