TensorFlow 数据集生成器输出的图像颜色反转

时间:2019-03-04 09:39:39

标签: python tensorflow tensorflow-datasets

我在使用 TF 数据集生成器时遇到了问题。不知道为什么,当我通过运行会话从数据集中获取图片时,返回的张量中颜色是反转的。我尝试过把 BGR 转换为 RGB,但问题不在这里。通过反转图像数组(img = 1-img)可以部分解决,但我希望从根源上避免这个问题。有人知道是什么原因吗?

import os
import glob
import random


import tensorflow as tf
from tensorflow import Tensor


class PairGenerator(object):
    """Draw random (photo, same_person) samples from a directory of per-person folders.

    The expected layout is ``<lfw_path>/<person_name>/*.jpg``: every
    sub-directory of ``lfw_path`` is treated as one person.
    """

    # Dict keys used by get_next_pair (person1, label) and read by Dataset.
    person1 = 'img'
    person2 = 'person2'  # not currently emitted by get_next_pair
    label = 'same_person'

    def __init__(self, lfw_path='/home/tom/Devel/ai-dev/tensorflow-triplet-loss/data/augmentor'):
        self.all_people = self.generate_all_people_dict(lfw_path)
        print(self.all_people.keys())

    def generate_all_people_dict(self, lfw_path):
        """Map each person (sub-directory name of ``lfw_path``) to a sorted list of
        that person's ``.jpg`` file paths.

        Plain files directly inside ``lfw_path`` are ignored. A person folder
        without any ``.jpg`` files maps to an empty list.
        """
        all_people = {}
        # sorted so that runs with a fixed random seed are reproducible
        for person_folder in sorted(os.listdir(lfw_path)):
            person_dir = os.path.join(lfw_path, person_folder)
            if not os.path.isdir(person_dir):
                continue  # stray files are not people
            # escape the directory part so names containing glob
            # metacharacters (e.g. '[') are matched literally
            pattern = os.path.join(glob.escape(person_dir), '*.jpg')
            all_people[person_folder] = sorted(glob.glob(pattern))
        return all_people

    def get_next_pair(self):
        """Yield ``{person1: photo_path, label: same_person}`` dicts forever.

        person1 is drawn uniformly from the people that have at least one
        photo, and a coin flip decides ``same_person``. When the flip is False
        a different person is drawn as person2, but only person1's photo is
        currently yielded.

        Raises:
            ValueError: on the first ``next()`` if fewer than two people have
                photos, because a "different person" could then never be drawn.
        """
        # people without photos would make random.choice(photos) raise IndexError
        all_people_names = [name for name, photos in self.all_people.items() if photos]
        if len(all_people_names) < 2:
            raise ValueError(
                'need at least two people with photos to draw pairs, found %d'
                % len(all_people_names))

        while True:
            # draw a person at random
            person1 = random.choice(all_people_names)
            # flip a coin to decide whether we fetch a photo of the same person vs different person
            same_person = random.random() > 0.5
            if same_person:
                person2 = person1
            else:
                # rejection sampling; terminates because names are unique and
                # there are at least two of them
                person2 = person1
                while person2 == person1:
                    person2 = random.choice(all_people_names)

            person1_photo = random.choice(self.all_people[person1])

            yield {self.person1: person1_photo,
                   self.label: same_person}





class Inputs(object):
    """Pair of model input tensors (image and label) plus a feed-dict builder.

    Args:
        img: tensor that receives the image batch.
        label: tensor that receives the label batch.
    """

    def __init__(self, img: Tensor, label: Tensor):
        self.img = img
        self.label = label

    def feed_input(self, input_img, input_label=None):
        """Return a feed dict mapping ``self.img`` to ``input_img``.

        ``self.label`` is mapped to ``input_label`` only when a label is given,
        so prediction-only calls (no loss computation) can omit it.
        """
        if input_label is None:
            return {self.img: input_img}
        return {self.img: input_img, self.label: input_label}




class Dataset(object):
    """tf.data pipeline that turns PairGenerator samples into batched image tensors.

    After construction, ``next_element`` is an ``Inputs`` whose ``img`` is a
    batch of 224x224x3 resized images and whose ``label`` is the matching
    ``same_person`` flag cast to float32.

    By default the image values are float32 in the [0, 255] range, because
    ``tf.image.resize_images`` returns float32 without rescaling. Pass
    ``normalize=True`` to get values in [0, 1] instead (which is what
    ``matplotlib.pyplot.imshow`` expects for float RGB data).
    """

    # Keys added to each dataset element by _read_image_and_resize.
    img_resized = 'img_resized'
    label = 'same_person'

    # Pipeline defaults; __init__ overrides them per instance.
    batch_size = 10
    prefetch_batch_buffer = 5
    normalize = False

    def __init__(self, generator=None, batch_size=10, prefetch_batch_buffer=5, normalize=False):
        """Build the input pipeline.

        Args:
            generator: PairGenerator that supplies samples. When None, a fresh
                ``PairGenerator()`` is created here instead of once at class
                definition time, so importing this module no longer touches the
                filesystem.
            batch_size: number of samples per batch.
            prefetch_batch_buffer: number of batches to prefetch.
            normalize: if True, divide pixel values by 255 so images lie in [0, 1].
        """
        if generator is None:
            generator = PairGenerator()
        self.batch_size = batch_size
        self.prefetch_batch_buffer = prefetch_batch_buffer
        self.normalize = normalize
        self.next_element = self.build_iterator(generator)

    def build_iterator(self, pair_gen: PairGenerator):
        """Create the tf.data pipeline over ``pair_gen`` and return its next batch as ``Inputs``."""
        dataset = tf.data.Dataset.from_generator(
            pair_gen.get_next_pair,
            output_types={PairGenerator.person1: tf.string,
                          PairGenerator.label: tf.bool})
        dataset = dataset.map(self._read_image_and_resize)
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(self.prefetch_batch_buffer)
        iterator = dataset.make_one_shot_iterator()
        element = iterator.get_next()

        return Inputs(element[self.img_resized],
                      element[PairGenerator.label])

    def _read_image_and_resize(self, pair_element):
        """Map function: load and resize the image named by ``pair_element``.

        Reads the file at ``pair_element[PairGenerator.person1]`` and stores the
        224x224x3 resized image under ``img_resized``. It also casts the
        ``same_person`` label to float32. Returns the same (updated) dict.
        """
        target_size = [224, 224]

        # read images from disk
        img_file = tf.read_file(pair_element[PairGenerator.person1])
        img = tf.image.decode_image(img_file, channels=3)

        # decode_image gives no static shape; declare unknown height/width and 3 (RGB) channels
        img.set_shape([None, None, 3])

        # resize to model input size; the result is float32 but keeps the
        # decoded image's 0-255 value range
        img_resized = tf.image.resize_images(img, target_size)
        if self.normalize:
            img_resized = img_resized / 255.0

        pair_element[self.img_resized] = img_resized
        pair_element[self.label] = tf.cast(pair_element[PairGenerator.label], tf.float32)

        return pair_element




generator = PairGenerator()
# print a few raw samples to sanity-check the generator before building the
# pipeline (named pair_iter so the builtin `iter` is not shadowed at module level)
pair_iter = generator.get_next_pair()
for _ in range(10):
    print(next(pair_iter))
ds = Dataset(generator)



import matplotlib.pyplot as plt

# Fetch one batch from the input pipeline by running a session, and keep the
# first image of the batch.
with tf.Session() as sess:
    out = sess.run(ds.next_element.img)[0]

# `out` is float32 with values in [0, 255] (tf.image.resize_images does not
# rescale). plt.imshow expects float RGB data in [0, 1] and clips values
# outside that range. That is why showing `out` directly, or `1 - out`, gives
# wrong, seemingly inverted colours. Scale to [0, 1] before displaying.
imgplot = plt.imshow(out / 255)

1 个答案:

答案 0(得分:0)

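好的,原因是 `tf.image.resize_images` 返回的是 float32 张量,但像素值仍在 [0, 255] 范围内;而 `plt.imshow` 要求浮点型 RGB 数据位于 [0, 1] 之间,超出范围的值会被截断,所以颜色看起来像是反转了。解决方法是在显示前除以 255: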

# `out` is presumably the float32 output of tf.image.resize_images (values in
# [0, 255]); plt.imshow expects float RGB data in [0, 1], so scale it first.
imgplot = plt.imshow(out / 255)