Tensorflow高效的Cyclegan历史记录池

时间:2019-02-03 23:35:43

标签: python tensorflow generative-adversarial-network

CycleGAN paper

提到了用于区分的历史记录池。因此我们保持最后一个来自发生器的50个样本,并将其馈入鉴别器。没有历史记录,这很简单,我们可以利用tf.data.Dataset和迭代器将数据插入网络。但是在历史记录池中,我没有获得如何使用tf.data.Dataset API的信息。 训练循环中的代码看起来像

fx, fy = sess.run(model_ops['fakes'], feed_dict={
    self.cur_x: cur_x,
    self.cur_y: cur_y,
})

cur_x, cur_y = sess.run([self.X_feed.feed(), self.Y_feed.feed()])
feeder_dict = {
    self.cur_x: cur_x,
    self.cur_y: cur_y,
    self.prev_fake_x: x_pool.query(fx, step),
    self.prev_fake_y: y_pool.query(fy, step),
}
# self.cur_x, self.cur_y, self.prev_fake_x, self.prev_fake_y are just placeholders
# x_pool and y_pool are simple wrappers for random sampling from the history pool and saving new images to the pool
for _ in range(dis_train):
    sess.run(model_ops['train']['dis'], feed_dict=feeder_dict)
for _ in range(gen_train):
    sess.run(model_ops['train']['gen'], feed_dict=feeder_dict)

困扰我的是代码的效率低下,例如无法像tf.data API的预取那样在训练过程中预加载下一批,但是我看不到任何使用tf.data API的方法。 它是否提供某种历史记录池,我可以将其用于预取并总体上优化数据加载模型? 同样,当我在鉴别器列op和发电机列op之间有一定比例时,也会出现类似的问题。 例如,如果我想每1步鉴别器运行2步发电机训练操作,是否可以使用相同的数据来完成?因为使用tf.data API,每次调用sess.run时都会从迭代器中提取新示例。

有什么方法可以正确有效地实现这一目标?

1 个答案:

答案 0 :(得分:0)

因此,我发现在TFGAN tensorflow contrib存储库中已经实现了历史记录池。