我使用Dataset API加载数据,并且某些输入图像无效。因此,我想使用数据集API的过滤器功能跳过它们。我知道可以通过使用tf.py_func并尝试在此处加载图像来实现,但是我想知道是否可以在没有py_func的情况下本机完成该操作。
def is_valid_image(filename):
# what to do here? How to do try-catch or other validation in tensorflow?
return True
def load_image(filename):
image_string = tf.read_file(filename)
image = tf.image.decode_png(image_string)
return image
with tf.device('/cpu:0'):
names = tf.data.Dataset.from_tensor_slices(image_paths)
names = names.filter(is_valid_image, num_parallel_calls=10) # some images are not valid, this filters them out
images = names.map(load_image, num_parallel_calls=10)
data = images.batch(batch_size=FLAGS.batchsize)
data = data.prefetch(buffer_size=3)
iterator = data.make_one_shot_iterator()
feeder = iterator.get_next()