TensorFlow 2.0:使用@ tf.function装饰器的函数不会使用numpy函数

时间:2019-04-14 20:04:54

标签: python numpy tensorflow tensorflow2.0

我正在编写一个函数来在TensorFlow 2.0中实现模型。它需要image_batch(一批numpy RGB格式的图像数据)并执行一些我需要的特定数据增强任务。导致我出现问题的行是:

@tf.function
def augment_data(image_batch, labels):
    import numpy as np
    from tensorflow.image import flip_left_right

    image_batch = np.append(image_batch, flip_left_right(image_batch), axis=0)

    [ ... ]
当我将numpy装饰器放在

.append()的{​​{1}}函数上时,它不再起作用。它返回:

  

ValueError:零维数组不能串联

当我在函数外部使用@tf.function命令时,或者顶部没有使用np.append()时,代码运行都没有问题。

这正常吗?我是否必须删除装饰器才能使其正常工作?还是由于TensorFlow 2.0仍然是beta版本而导致此错误?在那种情况下,我该如何解决?

1 个答案:

答案 0 :(得分:2)

只需将numpy ops包装到tf.py_function

def append(image_batch, tf_func):
    return np.append(image_batch, tf_func, axis=0)

@tf.function
def augment_data(image_batch):
    image = tf.py_function(append, inp=[image_batch, tf.image.flip_left_right(image_batch)], Tout=[tf.float32])
    return image