tf占位符功能做什么

时间:2019-05-15 05:06:46

标签: python-3.x tensorflow image-processing

我试图弄清楚这段代码的作用,但是我无法弄清楚它如何传递图像以及对图像的作用。

主要代码行就是这个

images1, images2 = preprocess(images, is_train, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH)

这很简单,它是一个获得想要的图像的函数。

现在参数图像是这样的:

    images = tf.placeholder(tf.float32, [2, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='images')
    is_train = tf.placeholder(tf.bool, name='is_train')

这是预处理功能:

def preprocess(images, is_train, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH):
def train():
    split = tf.split(images, [1, 1])
    shape = [1 for _ in range(split[0].get_shape()[1])]
    for i in range(len(split)):
        split[i] = tf.reshape(split[i], [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
        split[i] = tf.image.resize_images(split[i], [IMAGE_HEIGHT + 8, IMAGE_WIDTH + 3])
        split[i] = tf.split(split[i], shape)
        for j in range(len(split[i])):
            split[i][j] = tf.reshape(split[i][j], [IMAGE_HEIGHT + 8, IMAGE_WIDTH + 3, 3])
            split[i][j] = tf.random_crop(split[i][j], [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
            split[i][j] = tf.image.random_flip_left_right(split[i][j])
            split[i][j] = tf.image.random_brightness(split[i][j], max_delta=32. / 255.)
            split[i][j] = tf.image.random_saturation(split[i][j], lower=0.5, upper=1.5)
            split[i][j] = tf.image.random_hue(split[i][j], max_delta=0.2)
            split[i][j] = tf.image.random_contrast(split[i][j], lower=0.5, upper=1.5)
            split[i][j] = tf.image.per_image_standardization(split[i][j])
    return [tf.reshape(tf.concat(split[0], axis=0), [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3]),
        tf.reshape(tf.concat(split[1], axis=0), [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3])]
def val():
    split = tf.split(images, [1, 1])
    shape = [1 for _ in range(split[0].get_shape()[1])]
    for i in range(len(split)):
        split[i] = tf.reshape(split[i], [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
        split[i] = tf.image.resize_images(split[i], [IMAGE_HEIGHT, IMAGE_WIDTH])
        split[i] = tf.split(split[i], shape)
        for j in range(len(split[i])):
            split[i][j] = tf.reshape(split[i][j], [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
            split[i][j] = tf.image.per_image_standardization(split[i][j])
    return [tf.reshape(tf.concat(split[0], axis=0), [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3]),
        tf.reshape(tf.concat(split[1], axis=0), [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3])]
return tf.cond(is_train, train, val)

这是图像的全部代码

 if MODE == 'train':
    tarin_num_id = get_num_id(DATA_DIR, 'train')
elif MODE == 'eval':
    val_num_id = get_num_id(DATA_DIR, 'val')
images1, images2 = preprocess(images, is_train, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH)

我不知道它将如何处理图像并将其发送到网络。

谢谢您的帮助。

我正在处理的整个代码都来自这里 https://github.com/digitalbrain79/person-reid

1 个答案:

答案 0 :(得分:0)

此问题的答案是feed_dict是正在传递的内容,应包含所需的图像。      feed_dict = {images: test_images, is_train: False} 您可以通过像test_images这样的数组加载图像,然后将其传递给feed_dict。这样可以节省时间,因为您可以将不同的图像加载到feed_dict中,而无需更改用于训练,验证或测试的大量代码 谢谢@Chetan Vashisth指向feed_dict词典