在Tensorflow中设计Deep Q网络丢失功能

时间:2018-10-30 13:41:10

标签: python tensorflow

我正在使用我的第一个深度Q学习网络,但是在Tensorflow中设计损失函数时遇到了困难。由于损失函数同时使用当前网络权重和先前过时的权重(请参阅第二个公式here)。我只看到两种方法可以做到这一点:

  1. 预先预测过时的网络在更新时的所有可能状态,然后使用这些值
  2. 使用tf.train.Saver()并为每个反向传播在模型之间切换

在Tensorflow中是否还有其他更适合我的东西?

1 个答案:

答案 0 :(得分:2)

您需要将过时的权重存储在不同的tf.Varables中,以便以后使用。我无权访问您的任何代码,但建议您运行两次模型构造,并使用其中一个作为简单存储。另一个解决方案是使用每个变量中的两个来修改当前图形,并创建与副本的连接。

也就是说,如果您正在创建一个TensorFlow变量A,并且想要将其先前的值存储为B,则可以执行以下操作:

A = tf.Variable(5)
B = tf.Variable(0)

# Use A to do something
A = A * 5
# Store the value of A in B
B = A

with tf.Session() as sess:
    sess.run(B) # Store A in B
    sess.run(A) # Run an update on A

print A, B