构建图形时,我构建了三个数据输入管道。
images_pipe_1 = input_images('list1')
images_pipe_2 = input_images('list2')
images_pipe_3 = input_images('list3')
我想在图形运行时中根据global_step选择其中之一:
if global_step < 2000:
data input pipeline = images_pipe_1
if global_step >= 2000 and global_step < 5000
data input pipeline = images_pipe_2
if global_step >= 5000
data input pipeline = images_pipe_3
但是在tensorflow中,像global_step这样的变量是张量,它们应该由tf函数而不是python操作。 我尝试使用tf.cond,但是它只能解决两个选项的问题。
images_pipe = tf.cond(tf.greater(global_step, tf.constant(2000, tf.int64)), lambda:images_pipe_2, lambda:images_pipe_1)
在这种情况下,有三个选项。我不知道该如何解决。感谢您的提前帮助。
答案 0 :(得分:0)
我通过tf.case解决了
pipeline = tf.case({tf.greater(global_step, tf.constant(5000,tf.int64)):images-pipe_3, tf.less(global_step, tf.constant(2000,tf.int64):images_pipe_1)}, default=images_pipe_2, exclusive=True)