在tensorflow中使用sess.run()有多昂贵?

时间:2018-09-06 00:07:37

标签: python performance tensorflow machine-learning

我使用tensorflow在python中编写了机器学习算法。下图显示了算法伪代码。在这种算法中,我在训练循环中多次使用sess.run()。我必须使用多个sess.run()的原因是,我必须在不同的输入端评估相同的神经网络以计算δ。由于某些原因,我仍然不知道我的代码非常慢(请参阅codereviewai以查看代码和相关问题)。

enter image description here  图片取自Richard S. Sutton和Andrew G. Barto的Reinforcement Learning An Introduction书。

我对此堆栈的疑问如下:

1)做两个sess.run()而不是一个要多少钱。例如:

要做,

sess.run([op1],feed_dict={input:data})
sess.run([op2],feed_dict={input:data}) 

而不是

sess.run([op1,op2],feed_dict={input:data})

有什么区别吗?

2)在同一步骤的不同输入处评估同一神经网络的有效方法是什么?

我目前正在按如下方式计算δ:

self.delta = self.time_step_info['r'] + (not self.time_step_info['d'])*self.gamma*sess.run(self.critic(),feed_dict={self.state_in:self.time_step_info['s1']}) - sess.run(self.critic(),feed_dict={self.state_in:self.time_step_info['s']})

1 个答案:

答案 0 :(得分:0)

对于您的第一个问题,我不确定。

对于第二个问题,您可能已经知道,输入应该是矩阵。矩阵可以包含多个X。 NN将生成一个相应的结果矩阵Y,该矩阵Y的每一行都是X中行的输出。