将数据显式馈送到'tf.data.Dataset'的性能是否会受到影响?

时间:2019-05-13 01:33:32

标签: python tensorflow tensorflow-datasets

我正在实现RL算法,并使用tf.data.Dataset(与prefetch一起)将数据馈送到神经网络。但是,为了与环境交互,我必须通过feed_dict显式提供数据以采取措施。我想知道将feed_dictDataset一起使用是否会损害速度。

这是我的代码的简化版本

# code related to Dataset
ds = tf.data.Dataset.from_generator(buffer, sample_types, sample_shapes)
ds = ds.prefetch(5)
iterator = ds.make_one_shot_iterator()
samples = iterator.get_next(name='samples')
# pass samples to network
# network training, no feed_dict is needed because of Dataset
sess.run([self.opt_op])
# run the actor network to choose an action at the current state.
# manually feed the current state to samples
# will this impair the performance?
action = sess.run(self.action, feed_dict={samples['state']: state})

1 个答案:

答案 0 :(得分:1)

混合使用Dataset和feed_dict没什么错。如果您提供给feed_dict的状态很大,则可能会导致GPU使用不足,具体取决于数据的大小。但这绝不会与使用数据集有关。

存在Dataset API的原因之一是避免模型匮乏并在训练期间提高GPU利用率。饥饿可能是由于数据从一个位置复制到另一个位置的原因而发生的:磁盘到内存,内存到GPU内存,由您命名。数据集尝试尽早开始执行大量的IO操作,以免在处理下一批数据时出现模型不足的情况。因此,基本上,数据集会尝试减少批次之间的时间。

在您的情况下,您可能不会因为使用feed_dict而失去任何性能。看来您无论如何都要通过环境交互来中断执行(因此可能未充分利用GPU)。

如果您想确定的话,请在使用feed_dict馈送实际状态时调整性能,而不要用恒定张量替换状态用法并比较速度。