如何在Tensorflow中的图形运行时中选择三个数据输入管道之一?

时间:2018-09-30 21:08:39

标签: tensorflow graph

构建图形时,我构建了三个数据输入管道。

images_pipe_1 = input_images('list1')
images_pipe_2 = input_images('list2')
images_pipe_3 = input_images('list3')

我想在图形运行时中根据global_step选择其中之一:

if global_step < 2000:
  data input pipeline = images_pipe_1
if global_step >= 2000 and global_step < 5000
  data input pipeline = images_pipe_2
if global_step >= 5000
  data input pipeline = images_pipe_3

但是在tensorflow中,像global_step这样的变量是张量,它们应该由tf函数而不是python操作。 我尝试使用tf.cond,但是它只能解决两个选项的问题。

images_pipe = tf.cond(tf.greater(global_step, tf.constant(2000, tf.int64)), lambda:images_pipe_2, lambda:images_pipe_1)

在这种情况下,有三个选项。我不知道该如何解决。感谢您的提前帮助。

1 个答案:

答案 0 :(得分:0)

我通过tf.case解决了

pipeline = tf.case({tf.greater(global_step, tf.constant(5000,tf.int64)):images-pipe_3, tf.less(global_step, tf.constant(2000,tf.int64):images_pipe_1)}, default=images_pipe_2, exclusive=True)