我想应用Tensorflow's Dataset API来训练每次通过网络传播一批数据时更改的数据集。
我遇到了下面的代码,该代码使用feed_dict实现向Tensorflow馈送数据,我想对其进行修改以使用Tensorflow API,因为Tensorflow自己说
代码的相关部分(关于Q学习的实现)是:
def generate_session(t_max=1000, epsilon=0, train=False):
"""play env with approximate q-learning agent and train it at the same time"""
total_reward = 0
s = env.reset()
for t in range(t_max):
a = get_action(s, epsilon=epsilon)
next_s, r, done, _ = env.step(a)
if train:
sess.run(train_step,feed_dict={
states_ph: [s], actions_ph: [a], rewards_ph: [r],
next_states_ph: [next_s], is_done_ph: [done]
})
total_reward += r
s = next_s
if done: break
return total_reward
我想使用Tensorflow Data API,但是这里的问题是,馈送的所有数据:s, a, r, next_s, is_done_ph
取决于训练迭代的输出。换句话说,s, a, r, next_s, is_done_ph
处t=50
的输入值由s, a, r, next_s, is_done_ph
处t=49
的输出创建。这是因为
a = get_action(s, epsilon=epsilon)
根据预先训练步骤中s的输出创建新动作 然后
next_s, r, done, _ = env.step(a)
基本上为我们提供了训练循环的其余新输入。
现在的问题是,Tensorflow Dataset API中的示例使用的是训练开始之前就已经知道的训练数据,但是我不确定如何用这个不断发展的数据集来实现Tensorflow Dataset API。