每次迭代get_weights都很慢

时间:2019-10-26 01:54:08

标签: tensorflow keras python-3.7 tensorflow2.0

我正在从专用网络计算梯度并将其应用于另一个主网络。然后,我将主服务器的权重复制到私有服务器上(听起来有些多余,但请耐心等待)。问题在于,每次迭代get_weights都会变慢,我甚至会用光内存。

    def work(self, session):
        with session.as_default(), session.graph.as_default(): 
            self.private_net = ACNetwork()

            state = self.env.reset()

            while counter<TOTAL_TR_STEPS:

                action_index, action_vector = self.get_action(state)
                next_state, reward, done, info = self.env.step(action_index)
                ....# store the new data : reward, state etc...
                if done == True:
                    # end of episode
                    state = self.env.reset()
                    a_grads, c_grads = self.private_net.get_gradients()
                    self.master.update_from_gradients(a_grads, c_grads)
                    self._update_worker_net()  #this is the slow one
                !!!!!!

这是使用get_weights的函数。

def _update_worker_net(self):
      self.private_net.actor_t.set_weights(\
                               self.master.actor_t.get_weights())
      self.private_net.critic.set_weights(\
                               self.master.critic.get_weights())
return

环顾四周,我发现了一条建议使用

的帖子
  K.clear_session()
在while块的末尾(!!!!!!段)

,因为以某种方式在图上添加了新节点(?!)。但是那个onle返回了一个错误:

AssertionError: Do not use tf.reset_default_graph() to clear nested graphs. If you need a cleared graph, exit the nesting and create a new graph.

是否有更快的重量传递方式?有没有不添加新节点的方法(如果确实发生了这种情况?)

1 个答案:

答案 0 :(得分:2)

通常在将新节点动态添加到图形时发生。情况示例:

while True:
    grad_op = optimizer.get_gradients()
    session.run([gradients])

get_gradients将在其中向图添加新操作。无论调用多少次,get_gradients返回的操作都不会更改,因此,一个调用就足够了。重写它的正确方法是:

grad_op = optimizer.get_gradients()
while True:
    session.run([gradients])

类似的事情可能正在您的代码中发生。尝试确保您不在while循环内构造新操作。