Tensorflow在循环中减慢推理速度

时间:2018-03-10 23:35:49

标签: tensorflow reinforcement-learning

我正在使用Tensorflow进行强化学习实施。在对训练程序进行剖析后,我发现了一些非常奇怪的东西:

以下代码位于训练循环中:

    state_batch, \
    action_batch, \
    reward_batch, \
    next_state_batch, \
    is_episode_finished_batch = self.data_manager.get_next_batch()

    state_batch = np.divide(state_batch, 10.0)
    next_state_batch = np.divide(next_state_batch, 10.0)

    # Calculate y for the td_error of the critic
    y_batch = []

    next_action_batch = self.actor_network.target_evaluate(
        next_state_batch, action_batch)

    q_value_batch = self.critic_network.target_evaluate(
        next_state_batch, next_action_batch)

    for i in range(0, self.batch_size):
        if is_episode_finished_batch[i]:
            y_batch.append([reward_batch[i]])
        else:
            y_batch.append(reward_batch[i] + GAMMA * q_value_batch[i])

    # Now that we have the y batch, train the critic
    self.critic_network.train(y_batch, state_batch, action_batch)

    # Then get the action gradient batch and adapt the gradient with the gradient inverting method
    action_batch_for_gradients = self.actor_network.evaluate(
        state_batch, action_batch)

    q_gradient_batch = self.critic_network.get_action_gradient(
        state_batch, action_batch_for_gradients)

    q_gradient_batch = self.grad_inv.invert(
        q_gradient_batch, action_batch_for_gradients)

    # Now we can train the actor
    self.actor_network.train(q_gradient_batch, state_batch, action_batch)

actor_networkcritic_network是两个在演员评论算法中实现演员和评论家的类。他们每个人都有自己的网络和操作,但都在同一个图表中,并且将在同一个会话中运行。每个成员函数(如evaluate,train ...)都包含一个session.run,并通过传递参数来提供所需的数据。

我观察到action_batch_for_gradients运行极慢,花费0.x秒进行一次推理,甚至比self.critic_network.train慢得多。 action_batch_for_gradients只是演员网络中的推理操作以获取操作。然后我复制这一行并复制它,发现在action_batch_for_gradients之后只有第一个self.critic_network.train很慢,但第二个是正向操作的正常速度。我认为这与在图表之间切换,在培训网络和在另一个网络中转发之间有关。但我不知道如何避免。

我在stackoverflow上发现了一些关于在循环中使用相同图形的讨论,而不是每次都建立新图形,以加快使用张量流。但是我已经预先构建了图形,并且只在训练循环中运行图形的不同部分。所以我不知道我在这个循环训练中错误地使用tensorflow。我使用的是Tensorflow 1.6。

感谢您的帮助!

0 个答案:

没有答案