将Tensorflow Dataset API应用于在训练过程中更改的数据集

时间:2018-08-04 19:41:45

标签: python tensorflow tensorflow-datasets q-learning

我想应用Tensorflow's Dataset API来训练每次通过网络传播一批数据时更改的数据集。

我遇到了下面的代码,该代码使用feed_dict实现向Tensorflow馈送数据,我想对其进行修改以使用Tensorflow API,因为Tensorflow自己说

"Feeding" is the least efficient way to feed data into a TensorFlow program and should only be used for small experiments and debugging.

代码的相关部分(关于Q学习的实现)是:

def generate_session(t_max=1000, epsilon=0, train=False):
    """play env with approximate q-learning agent and train it at the same time"""
    total_reward = 0
    s = env.reset()

    for t in range(t_max):
        a = get_action(s, epsilon=epsilon)       
        next_s, r, done, _ = env.step(a)

        if train:
            sess.run(train_step,feed_dict={
                states_ph: [s], actions_ph: [a], rewards_ph: [r], 
                next_states_ph: [next_s], is_done_ph: [done]
            })

        total_reward += r
        s = next_s
        if done: break

    return total_reward

我想使用Tensorflow Data API,但是这里的问题是,馈送的所有数据:s, a, r, next_s, is_done_ph取决于训练迭代的输出。换句话说,s, a, r, next_s, is_done_pht=50的输入值由s, a, r, next_s, is_done_pht=49的输出创建。这是因为

a = get_action(s, epsilon=epsilon)       

根据预先训练步骤中s的输出创建新动作 然后

next_s, r, done, _ = env.step(a)

基本上为我们提供了训练循环的其余新输入。

现在的问题是,Tensorflow Dataset API中的示例使用的是训练开始之前就已经知道的训练数据,但是我不确定如何用这个不断发展的数据集来实现Tensorflow Dataset API。

0 个答案:

没有答案