I am trying to convert this Pix2Pix source (which uses a queue-based input pipeline) to the TensorFlow Dataset API.
The source loads all the raw images from an input directory and then uses tf.train.batch() to fetch a batch of that data for further processing in the graph. I am therefore trying to modify its load_examples() function so that it uses the Dataset API instead.
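For reference, the queue-based pattern being replaced looks roughly like this (a minimal sketch with illustrative names such as input_paths and batch_size, not the exact pix2pix code):

# Sketch of the old queue-based pipeline (TF 1.x).
path_queue = tf.train.string_input_producer(input_paths, shuffle=True)
reader = tf.WholeFileReader()
path, contents = reader.read(path_queue)
image = tf.image.decode_jpeg(contents, channels=3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# Resize to a fixed size so tf.train.batch() sees fully defined shapes.
image = tf.image.resize_images(image, [256, 256])
# tf.train.batch() pulls batch_size preprocessed examples off the queue.
paths_batch, images_batch = tf.train.batch([path, image], batch_size=batch_size)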
Here is the modified function:
import math
import os
import random

import tensorflow as tf

# num_classes, CROP_SIZE, tfrecords_dir, the Examples namedtuple, and the
# helpers rgb_to_lab, preprocess_lab and preprocess are defined elsewhere
# in the source.

def _parse_function(example, a):
    features = tf.parse_single_example(
        example,
        features={
            'image/class/label': tf.FixedLenFeature([], tf.int64),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/path': tf.FixedLenFeature([], dtype=tf.string, default_value='')})
    label = features['image/class/label']
    image_encoded = features['image/encoded']
    path = features['image/path']

    # Decode the JPEG and convert it to float in [0, 1].
    image = tf.image.decode_jpeg(image_encoded, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    with tf.name_scope("load_images"):
        # Check that the image has 3 channels, i.e. is RGB.
        assertion = tf.assert_equal(tf.shape(image)[2], 3, message="image does not have 3 channels")
        with tf.control_dependencies([assertion]):
            image = tf.identity(image)

        image.set_shape([None, None, 3])

        if a.lab_colorization:
            # Load lightness and color from the image; no B image exists here.
            lab = rgb_to_lab(image)
            L_chan, a_chan, b_chan = preprocess_lab(lab)
            a_images = tf.expand_dims(L_chan, axis=2)
            b_images = tf.stack([a_chan, b_chan], axis=2)
        else:
            # Break apart the side-by-side image pair and move to range [-1, 1].
            width = tf.shape(image)[1]  # [height, width, channels]
            a_images = preprocess(image[:, :width // 2, :])
            b_images = preprocess(image[:, width // 2:, :])

    if a.which_direction == "AtoB":
        # From font to skeleton image.
        inputs, targets = [a_images, b_images]
    elif a.which_direction == "BtoA":
        # From skeleton to font image.
        inputs, targets = [b_images, a_images]
    else:
        raise Exception("invalid direction")

    # Synchronize the seed for the image operations so that we apply the same
    # random transformations to both the input and the target image.
    # transform() applies the resize/crop preprocessing to each image.
    seed = random.randint(0, 2**31 - 1)

    def transform(image):
        r = image
        # Flip the hangul or skeleton image left to right. Note that the
        # `if a.flip:` guard from the original source is commented out, so
        # the random flip is always enabled.
        # if a.flip:
        r = tf.image.random_flip_left_right(r, seed=seed)

        # AREA produces a nice downscaling but does nearest-neighbour for
        # upscaling; assume we are downscaling here.
        r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)

        offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
        if a.scale_size > CROP_SIZE:
            r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
        elif a.scale_size < CROP_SIZE:
            raise Exception("scale size cannot be less than crop size")
        return r

    with tf.name_scope("input_images"):
        input_images = transform(inputs)

    with tf.name_scope("target_images"):
        target_images = transform(targets)

    # Represent the label as a one-hot vector.
    label = tf.one_hot(label, num_classes)

    return input_images, target_images, label, path
# Load the hangul-skeleton images from the tfrecords directory and apply the
# preprocessing above.
def load_examples(a):
    total_records = 0
    if a.mode == "test":
        print('We are testing the model, so we will choose the test tfrecords files')
        tf_record_pattern = os.path.join(tfrecords_dir, '%s-*' % 'test')
        test_data_files = tf.gfile.Glob(tf_record_pattern)

        # Create the test dataset input pipeline.
        test_dataset = tf.data.TFRecordDataset(test_data_files) \
            .map(lambda example: _parse_function(example, a)) \
            .batch(a.batch_size) \
            .prefetch(1)
        iterator = test_dataset.make_one_shot_iterator()
        batch = iterator.get_next()

        # Count the total number of records.
        for fn in test_data_files:
            for record in tf.python_io.tf_record_iterator(fn):
                total_records += 1
    else:
        print('We are training the model, so we will choose the train tfrecords files')
        tf_record_pattern = os.path.join(tfrecords_dir, '%s-*' % 'train')
        train_data_files = tf.gfile.Glob(tf_record_pattern)

        # Create the training dataset input pipeline.
        train_dataset = tf.data.TFRecordDataset(train_data_files) \
            .map(lambda example: _parse_function(example, a)) \
            .shuffle(1000) \
            .repeat(count=None) \
            .batch(a.batch_size) \
            .prefetch(1)
        iterator = train_dataset.make_one_shot_iterator()
        batch = iterator.get_next()

        # Count the total number of records.
        for fn in train_data_files:
            for record in tf.python_io.tf_record_iterator(fn):
                total_records += 1

    # Each get_next() call yields batch_size paths, input images and target
    # images; compute the number of steps per epoch from the record count.
    input_images, target_images, label, path = batch
    steps_per_epoch = int(math.ceil(total_records / a.batch_size))

    return Examples(
        paths=path,
        inputs=input_images,
        targets=target_images,
        count=total_records,
        steps_per_epoch=steps_per_epoch,
    )
Since I am new to TensorFlow, I just want to know whether this is the correct way to use the Dataset API. Usually we run the iterator inside sess.run(), but since this source processes the data inside the graph, I don't know whether I can (or should) use a session object inside the load_examples() function.
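For context, this is roughly how I consume the returned tensors in my training loop (a sketch only; build_model and num_epochs stand in for my real training ops and settings):

examples = load_examples(a)
train_op = build_model(examples.inputs, examples.targets)  # stand-in for the real model

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(examples.steps_per_epoch * num_epochs):
        # Each sess.run() implicitly pulls the next batch through the
        # one-shot iterator; no explicit iterator handling is needed here.
        sess.run(train_op)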
So far everything seems to work. The only problem I see is that some of my generated images come out flipped, and I don't know why; the flipping is done by tf.image.random_flip_left_right() in the transform() step of the parse function above.
Also, is there a way to improve the structure of this code?
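For example, I was wondering whether something like the following would be a reasonable way to remove the duplication between the train and test branches (an untested sketch, intended to behave the same as the version above):

def load_examples(a):
    # Pick the tfrecords split based on the mode; everything else is shared.
    split = 'test' if a.mode == 'test' else 'train'
    data_files = tf.gfile.Glob(os.path.join(tfrecords_dir, '%s-*' % split))

    dataset = tf.data.TFRecordDataset(data_files) \
        .map(lambda example: _parse_function(example, a))
    if a.mode != 'test':
        # Shuffle and repeat indefinitely only when training.
        dataset = dataset.shuffle(1000).repeat()
    dataset = dataset.batch(a.batch_size).prefetch(1)

    input_images, target_images, label, path = \
        dataset.make_one_shot_iterator().get_next()

    # Count the records once, over whichever files were globbed.
    total_records = sum(1 for fn in data_files
                        for _ in tf.python_io.tf_record_iterator(fn))

    return Examples(
        paths=path,
        inputs=input_images,
        targets=target_images,
        count=total_records,
        steps_per_epoch=int(math.ceil(total_records / a.batch_size)),
    )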