我有一个使用string_input_producer函数和decode_png函数加载训练数据的函数。我希望以动态方式对图像运行过滤器,而不是预处理所有数据。该函数(第60行 - bm3d)将numpy ndarray作为输入,并将过滤后的数组作为输出。
我无法弄清楚如何将此功能应用于每个输入图像。必须在调整大小之前在代码中应用它。怎么办呢?
def load_examples():
if a.input_dir is None or not os.path.exists(a.input_dir):
raise Exception("input_dir does not exist")
input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
decode = tf.image.decode_jpeg
if len(input_paths) == 0:
input_paths = glob.glob(os.path.join(a.input_dir, "*.png"))
decode = tf.image.decode_png
if len(input_paths) == 0:
raise Exception("input_dir contains no image files")
def get_name(path):
name, _ = os.path.splitext(os.path.basename(path))
return name
# if the image names are numbers, sort by the value rather than asciibetically
# having sorted inputs means that the outputs are sorted in test mode
if all(get_name(path).isdigit() for path in input_paths):
input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
else:
input_paths = sorted(input_paths)
with tf.name_scope("load_images"):
path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
reader = tf.WholeFileReader()
paths, contents = reader.read(path_queue)
raw_input = decode(contents)
raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)
assertion = tf.assert_equal(tf.shape(raw_input)[2], 1, message="image does not have 1 channel")
with tf.control_dependencies([assertion]):
raw_input = tf.identity(raw_input)
raw_input.set_shape([None, None, 1])
# break apart image pair and move to range [-1, 1]
width = tf.shape(raw_input)[1] # [height, width, channels]
flag_input = (a.which_direction == "AtoB")
a_images = preprocess(raw_input[:,:width//2,:])
b_images = preprocess(raw_input[:,width//2:,:])
if a.which_direction == "AtoB":
inputs, targets = [a_images, b_images]
elif a.which_direction == "BtoA":
inputs, targets = [b_images, a_images]
else:
raise Exception("invalid direction")
# synchronize seed for image operations so that we do the same operations to both
# input and output images
seed = random.randint(0, 2**31 - 1)
def transform(image, inputs=True):
r = image
if a.flip:
r = tf.image.random_flip_left_right(r, seed=seed)
if a.bm3d and inputs:
r = tf.py_func(bm3d_filt, [r], tf.float32)
# area produces a nice downscaling, but does nearest neighbor for upscaling
# assume we're going to be doing downscaling here
r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)
offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
if a.scale_size > CROP_SIZE:
r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
elif a.scale_size < CROP_SIZE:
raise Exception("scale size cannot be less than crop size")
return r
with tf.name_scope("input_images"):
input_images = transform(inputs, inputs=True)
with tf.name_scope("target_images", inputs=False):
target_images = transform(targets)
paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))
return Examples(
paths=paths_batch,
inputs=inputs_batch,
targets=targets_batch,
count=len(input_paths),
steps_per_epoch=steps_per_epoch,
)
Python的错误是:
Traceback (most recent call last):
File "pix2pix_mod.py", line 709, in <module>
main()
File "pix2pix_mod.py", line 508, in main
examples = load_examples()
File "pix2pix_mod.py", line 192, in load_examples
input_images = transform(inputs, inputs=True)
File "pix2pix_mod.py", line 182, in transform
r = tf.image.resize_images(r, [a.scale_size, a.scale_size],
method=tf.image.ResizeMethod.AREA)
File "/home/ne63wog/anaconda3/envs/pix2pix/lib/python3.5/site-
packages/tensorflow/python/ops/image_ops_impl.py", line 741, in
resize_images
raise ValueError('\'images\' contains no shape.')
ValueError: 'images' contains no shape.