使用python函数和tf.Dataset API

时间:2018-06-12 16:14:19

标签: python tensorflow deep-learning dataset tensorflow-datasets

我正在寻找动态阅读的图片,并为我的图像分割问题应用数据增强。从我到目前为止看来,最好的方法是具有tf.Dataset功能的.map API。

但是,从我看过的例子中我认为我必须使我的所有功能都适应张量流式(使用tf.cond而不是if等)。问题是我需要应用一些非常复杂的功能。因此我正在考虑使用这样的tf.py_func

import tensorflow as tf

img_path_list = [...]   # List of paths to read
mask_path_list = [...]  # List of paths to read

dataset = tf.data.Dataset.from_tensor_slices((img_path_list, mask_path_list))

def parse_function(img_path_list, mask_path_list):
    '''load image and mask from paths'''
    return img, mask

def data_augmentation(img, mask):
    '''process data with complex logic'''
    return aug_img, aug_mask

# py_func wrappers
def parse_function_wrapper(img_path_list, mask_path_list):
    return tf.py_func(func=parse_function,
                      inp=(img_path_list, mask_path_list),
                      Tout=(tf.float32, tf.float32))

def data_augmentation_wrapper(img, mask):
    return tf.py_func(func=data_augmentation,
                      inp=(img, mask),
                      Tout=(tf.float32, tf.float32))        

# Maps py_funcs to dataset
dataset = dataset.map(parse_function_wrapper,
                      num_parallel_calls=4)
dataset = dataset.map(data_augmentation_wrapper,
                      num_parallel_calls=4)

dataset = dataset.batch(32)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

但是,从this answer开始,似乎使用py_func进行并行操作不起作用。还有其他选择吗?

1 个答案:

答案 0 :(得分:1)

py_func受到python GIL的限制,所以你不会在那里得到很多并行性。您最好的办法是在tensorflow中正确编写数据扩充(或预先计算并将其序列化为磁盘)。

如果你想在tensorflow中编写它,你可以尝试使用tf.contrib.autograph将简单的python ifs和for循环转换为tf.conds和tf.while_loops,这可能会简化你的代码。