Tensorflow数据集API过滤器无效图像

时间:2018-10-08 19:55:49

标签: python tensorflow

我使用Dataset API加载数据,并且某些输入图像无效。因此,我想使用数据集API的过滤器功能跳过它们。我知道可以通过使用tf.py_func并尝试在此处加载图像来实现,但是我想知道是否可以在没有py_func的情况下本机完成该操作。

def is_valid_image(filename):
    # what to do here? How to do try-catch or other validation in tensorflow?
    return True

def load_image(filename):
    image_string = tf.read_file(filename)
    image = tf.image.decode_png(image_string)
    return image

with tf.device('/cpu:0'):
    names = tf.data.Dataset.from_tensor_slices(image_paths)
    names = names.filter(is_valid_image, num_parallel_calls=10)   # some images are not valid, this filters them out
    images = names.map(load_image, num_parallel_calls=10)
    data = images.batch(batch_size=FLAGS.batchsize)
    data = data.prefetch(buffer_size=3)
    iterator = data.make_one_shot_iterator()
    feeder = iterator.get_next()

0 个答案:

没有答案