我在TensorFlow中的数据扩充代码似乎正在运行,但没有正确地用于预期目的。我有一个张量为'in_batch'的图像,形状为[in_batch, height, width, 3]
,并应用下面的data_augment
函数。
与没有它的模型相比,我应用数据扩充时运行时间增加,因此该函数肯定正在运行。然而,测试精度显着下降,即约80%没有数据增加到约10%。发生了什么事?
def prep_data_augment(input_tensor, i):
image = tf.slice(input_tensor, [i, 0, 0, 0], [1, 32, 32, 3])[0]
image = tf.image.random_flip_left_right(image)
image = tf.image.random_brightness(image, max_delta=63)
image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
image = tf.expand_dims(image, 0)
return image
def data_augment(input_tensor, in_batch):
fake_i = tf.Variable(0, dtype=tf.int32, trainable=False)
i = 0
output_tensor = prep_data_augment(input_tensor, i)
i += 1
def while_body(output_tensor, i):
output_tensor = tf.concat(0, [output_tensor, prep_data_augment(input_tensor, i)])
i += 1
return (output_tensor, i)
(output_tensor, i) = tf.while_loop(lambda output_tensor, i: tf.less(i, in_batch),
lambda output_tensor, i: while_body(output_tensor, i),
(output_tensor, i),
shape_invariants = (tf.TensorShape([None, 32, 32, 3]), fake_i.get_shape()))
return output_tensor
[更新] 不使用while循环,可以使用如下的map函数简化上述函数。
def prep_data_augment(image):
image = tf.image.random_flip_left_right(image)
image = tf.image.random_brightness(image, max_delta=63/255.0)
image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
return image
def data_augment(input_tensor):
output_tensor = tf.map_fn(prep_data_augment, input_tensor)
return output_tensor