如何在张量流中水平翻转标记的方向

时间:2018-05-20 21:44:07

标签: python tensorflow

我正在使用数据集API扩充数据,以便在tensorflow中学习,其中扩充是作为图形的一部分编写的。对于我的应用程序,我需要修改标签和图像。标签将方向编码为:

0: up
1: right
2: down
3: left

对于旋转增强,我可以这样做:

rotated_image = tf.image.rot90(image, 1)
rotated_label = (label + 1) % 4

我可以使用以下方法水平翻转特征图像:

hflipped_image = tf.image.flip_left_right(image)

但我无法弄清楚如何翻转左边的标签< - >右。

如果标签为1则应为3,反之亦然。 0和2应保持不变。我怎么能这样做?

1 个答案:

答案 0 :(得分:1)

flipped_label = tf.gather([0, 3, 2, 1], label)