在 TensorFlow 数据加载期间运行 Python 函数

时间:2018-02-12 19:25:04

标签: python numpy tensorflow

我有一个使用 string_input_producer 和 decode_png 函数加载训练数据的函数。我希望在加载过程中动态地对图像运行滤波器,而不是预先处理全部数据。该滤波函数(即代码第60行调用的 bm3d_filt)以 numpy ndarray 作为输入,并返回滤波后的数组。

我不知道该如何把这个函数应用到每一张输入图像上。它必须在代码中调整图像大小(resize)之前执行。应该怎么做?

def load_examples():
    """Build the TF1 queue-based input pipeline for paired single-channel images.

    Every ``*.jpg`` file in ``a.input_dir`` is read; if there are none, every
    ``*.png`` file is used instead. Each image must have exactly one channel.
    It is split down the middle into a left (A) half and a right (B) half,
    and each half goes through ``preprocess``. ``a.which_direction`` decides
    which half is the network input and which is the target. Both halves then
    get the same (seed-synchronised) random flip, resize and random crop.
    The optional ``bm3d_filt`` filter is applied to the inputs only, before
    resizing. Finally the tensors are batched.

    Returns:
        Examples: record (defined elsewhere in the file) with fields
        ``paths``, ``inputs``, ``targets``, ``count`` and ``steps_per_epoch``.

    Raises:
        Exception: if ``a.input_dir`` is missing or contains no images,
            if ``a.which_direction`` is not "AtoB"/"BtoA", or if
            ``a.scale_size`` is smaller than ``CROP_SIZE``.
    """
    if a.input_dir is None or not os.path.exists(a.input_dir):
        raise Exception("input_dir does not exist")

    input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
    decode = tf.image.decode_jpeg
    if not input_paths:
        input_paths = glob.glob(os.path.join(a.input_dir, "*.png"))
        decode = tf.image.decode_png

    if not input_paths:
        raise Exception("input_dir contains no image files")

    def get_name(path):
        """Return the file name of *path* without its directory and extension."""
        name, _ = os.path.splitext(os.path.basename(path))
        return name

    # if the image names are numbers, sort by the value rather than asciibetically
    # having sorted inputs means that the outputs are sorted in test mode
    if all(get_name(path).isdigit() for path in input_paths):
        input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
    else:
        input_paths = sorted(input_paths)

    with tf.name_scope("load_images"):
        path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
        reader = tf.WholeFileReader()
        paths, contents = reader.read(path_queue)
        raw_input = decode(contents)
        raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)

        assertion = tf.assert_equal(tf.shape(raw_input)[2], 1, message="image does not have 1 channel")
        with tf.control_dependencies([assertion]):
            raw_input = tf.identity(raw_input)

        # The assertion above guarantees 1 channel at run time; tell the graph too.
        raw_input.set_shape([None, None, 1])

        # break apart image pair and move to range [-1, 1]
        width = tf.shape(raw_input)[1]  # [height, width, channels]
        a_images = preprocess(raw_input[:, :width // 2, :])
        b_images = preprocess(raw_input[:, width // 2:, :])

    if a.which_direction == "AtoB":
        inputs, targets = a_images, b_images
    elif a.which_direction == "BtoA":
        inputs, targets = b_images, a_images
    else:
        raise Exception("invalid direction")

    # synchronize seed for image operations so that we do the same operations to both
    # input and output images
    seed = random.randint(0, 2**31 - 1)

    def transform(image, is_input=True):
        """Flip, optionally filter, resize and crop a single image tensor.

        Args:
            image: image tensor from the pipeline above.
            is_input: True for the network input. The bm3d filter is only
                applied when this is True; pass False for targets.

        Returns:
            The transformed image tensor.
        """
        r = image
        if a.flip:
            r = tf.image.random_flip_left_right(r, seed=seed)

        if a.bm3d and is_input:
            # tf.py_func returns a tensor whose static shape is unknown, and
            # tf.image.resize_images then fails with "'images' contains no
            # shape". Re-attach the static shape the tensor had before the call.
            # NOTE(review): assumes bm3d_filt returns an array with the same
            # shape as its input and dtype np.float32 (py_func requires the
            # dtype to match Tout) -- confirm in bm3d_filt.
            static_shape = r.get_shape()
            r = tf.py_func(bm3d_filt, [r], tf.float32)
            r.set_shape(static_shape)

        # area produces a nice downscaling, but does nearest neighbor for upscaling
        # assume we're going to be doing downscaling here
        r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)

        # Validate before building the crop ops that depend on this difference.
        if a.scale_size < CROP_SIZE:
            raise Exception("scale size cannot be less than crop size")

        # The shared op seed keeps the crop window identical for input and target.
        offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
        if a.scale_size > CROP_SIZE:
            r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
        return r

    with tf.name_scope("input_images"):
        input_images = transform(inputs, is_input=True)

    with tf.name_scope("target_images"):
        # The bm3d filter is applied to the inputs only, never to the targets.
        target_images = transform(targets, is_input=False)

    paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
    steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))

    return Examples(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )

Python 报错如下:

Traceback (most recent call last):
File "pix2pix_mod.py", line 709, in <module>
  main()
File "pix2pix_mod.py", line 508, in main
  examples = load_examples()
File "pix2pix_mod.py", line 192, in load_examples
  input_images = transform(inputs, inputs=True)
File "pix2pix_mod.py", line 182, in transform
  r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)
File "/home/ne63wog/anaconda3/envs/pix2pix/lib/python3.5/site-packages/tensorflow/python/ops/image_ops_impl.py", line 741, in resize_images
   raise ValueError('\'images\' contains no shape.')
ValueError: 'images' contains no shape.

0 个答案:

没有答案