How to convert a queue-based input pipeline to the TensorFlow Dataset API

Date: 2019-05-22 10:12:31

Tags: python tensorflow generative-adversarial-network

I am trying to convert this Pix2Pix source (which uses a queue-based input pipeline) to the TensorFlow Dataset API.

The source loads all the raw images from the input directory and then uses tf.train.batch() to assemble batches of that data for further processing in the graph. I am therefore trying to modify its load_examples() function so that it uses the Dataset API instead.
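For context, the queue-based pattern I am replacing looks roughly like this (a paraphrased sketch, not the exact pix2pix code; input_paths and the preprocessing placeholder are illustrative):

# Queue-based input pipeline (TF 1.x): file names are pushed into a queue,
# a reader decodes them, and tf.train.batch() assembles batches in the graph.
path_queue = tf.train.string_input_producer(input_paths, shuffle=True)
reader = tf.WholeFileReader()
paths, contents = reader.read(path_queue)
raw_input = tf.image.decode_jpeg(contents)
# ... per-image preprocessing producing input_images / target_images ...
paths_batch, inputs_batch, targets_batch = tf.train.batch(
    [paths, input_images, target_images], batch_size=batch_size)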

Below are the modified functions:

import math
import os
import random

import tensorflow as tf

# Helpers such as rgb_to_lab, preprocess_lab, preprocess, Examples, CROP_SIZE,
# tfrecords_dir and num_classes are defined elsewhere in the script.

def _parse_function(example, a):
    # Parse a single serialized tf.train.Example into label, encoded image and path.
    features = tf.parse_single_example(
        example,
        features={
            'image/class/label': tf.FixedLenFeature([], tf.int64),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/path': tf.FixedLenFeature([], dtype=tf.string, default_value='')})
    label = features['image/class/label']
    image_encoded = features['image/encoded']
    path = features['image/path']

    # Decode the JPEG.
    image = tf.image.decode_jpeg(image_encoded, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # image = tf.reshape(image, [IMAGE_WIDTH*IMAGE_HEIGHT])

    with tf.name_scope("load_images"):
        # Check that the image has 3 channels, i.e. is RGB.
        assertion = tf.assert_equal(tf.shape(image)[2], 3, message="image does not have 3 channels")
        with tf.control_dependencies([assertion]):
            image = tf.identity(image)

        image.set_shape([None, None, 3])

        if a.lab_colorization:
            # Load color and brightness from the image; no B image exists here.
            lab = rgb_to_lab(image)
            L_chan, a_chan, b_chan = preprocess_lab(lab)
            a_images = tf.expand_dims(L_chan, axis=2)
            b_images = tf.stack([a_chan, b_chan], axis=2)
        else:
            # Break apart the image pair and move to the range [-1, 1].
            width = tf.shape(image)[1]  # [height, width, channels]
            a_images = preprocess(image[:, :width//2, :])
            b_images = preprocess(image[:, width//2:, :])

    # "AtoB" means from font to skeleton image.
    if a.which_direction == "AtoB":
        inputs, targets = [a_images, b_images]
    # "BtoA" means from skeleton to font image.
    elif a.which_direction == "BtoA":
        inputs, targets = [b_images, a_images]
    else:
        raise Exception("invalid direction")

    # Synchronize the seed for image operations so that we apply the same
    # random operations to both the input and the target image.
    # transform() simply applies some preprocessing (flip, resize, crop) to an image.
    seed = random.randint(0, 2**31 - 1)

    def transform(image):
        r = image
        # Flip the hangul or skeleton image from left to right.
        # Note: the `if a.flip:` guard is commented out, so the random flip
        # is always applied.
        # if a.flip:
        r = tf.image.random_flip_left_right(r, seed=seed)

        # AREA produces a nice downscaling but does nearest neighbor for
        # upscaling; assume we are doing downscaling here.
        r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)

        offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
        if a.scale_size > CROP_SIZE:
            r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
        elif a.scale_size < CROP_SIZE:
            raise Exception("scale size cannot be less than crop size")
        return r

    with tf.name_scope("input_images"):
        input_images = transform(inputs)

    with tf.name_scope("target_images"):
        target_images = transform(targets)

    # Represent the label as a one-hot vector.
    label = tf.stack(tf.one_hot(label, num_classes))

    return input_images, target_images, label, path


# Load the hangul-skeleton images from the directory and apply some preprocessing.
def load_examples(a):
    total_records = 0
    if a.mode == "test":
        print('We are testing the model, so we will choose the test tfrecords files')
        tf_record_pattern = os.path.join(tfrecords_dir, '%s-*' % 'test')
        test_data_files = tf.gfile.Glob(tf_record_pattern)

        # Create the testing dataset input pipeline.
        test_dataset = tf.data.TFRecordDataset(test_data_files) \
            .map(lambda example: _parse_function(example, a)) \
            .batch(a.batch_size) \
            .prefetch(1)

        iterator = test_dataset.make_one_shot_iterator()
        batch = iterator.get_next()

        # Count the total number of records.
        for fn in test_data_files:
            for record in tf.python_io.tf_record_iterator(fn):
                total_records += 1
    else:
        print('We are training the model, so we will choose the train tfrecords files')
        tf_record_pattern = os.path.join(tfrecords_dir, '%s-*' % 'train')
        train_data_files = tf.gfile.Glob(tf_record_pattern)

        # Create the training dataset input pipeline.
        train_dataset = tf.data.TFRecordDataset(train_data_files) \
            .map(lambda example: _parse_function(example, a)) \
            .shuffle(1000) \
            .repeat(count=None) \
            .batch(a.batch_size) \
            .prefetch(1)

        iterator = train_dataset.make_one_shot_iterator()
        batch = iterator.get_next()

        # Count the total number of records.
        for fn in train_data_files:
            for record in tf.python_io.tf_record_iterator(fn):
                total_records += 1

    # Unpack the batch tensors: each step pulls batch_size examples from the
    # pipeline, and steps_per_epoch is the number of steps needed for one epoch.
    input_images, target_images, label, path = batch
    steps_per_epoch = int(math.ceil(total_records / a.batch_size))

    return Examples(
        paths=path,
        inputs=input_images,
        targets=target_images,
        count=total_records,
        steps_per_epoch=steps_per_epoch,
    )

Since I am new to TensorFlow, I would like to know whether this is the correct way to use the Dataset API. Normally we evaluate the iterator's output inside sess.run(), but since this source code processes the data inside the graph, I am not sure whether I can use a session object inside the load_examples() function.
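For clarity, this is roughly how I plan to consume the tensors returned by load_examples() (a minimal TF 1.x sketch, not my actual training loop):

examples = load_examples(a)

# The one-shot iterator needs no explicit initialization, so the tensors
# in `examples` can be evaluated directly inside a session.
with tf.Session() as sess:
    for step in range(examples.steps_per_epoch):
        inputs, targets = sess.run([examples.inputs, examples.targets])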

So far everything seems to work, but the only problem I see is that some of my generated images come out flipped, and I don't know why. The flipping is part of the parse function above.
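For reference, this is what the flip would look like if it were gated on the original --flip flag (just a sketch; in my version above the guard is commented out):

# Inside transform(): only flip when the --flip command-line flag is set.
if a.flip:
    r = tf.image.random_flip_left_right(r, seed=seed)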

Is there a way to improve the structure of this code?

0 Answers:

No answers yet.