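I am trying to work out what this code does, but I cannot figure out how the images are passed in or what is done to them.
The main line is this one: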
images1, images2 = preprocess(images, is_train, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH)
That part is simple: it is a function that returns the images the network needs.
The images argument is defined like this:
images = tf.placeholder(tf.float32, [2, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='images')
is_train = tf.placeholder(tf.bool, name='is_train')
This is the preprocess function:
def preprocess(images, is_train, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH):
    def train():
        # Split axis 0 into the two sides of each pair: 2 x [1, BATCH_SIZE, H, W, 3].
        split = tf.split(images, [1, 1])
        # One split size of 1 per image in the batch.
        shape = [1 for _ in range(split[0].get_shape()[1])]
        for i in range(len(split)):
            split[i] = tf.reshape(split[i], [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
            # Enlarge slightly so the random crop below has room to move.
            split[i] = tf.image.resize_images(split[i], [IMAGE_HEIGHT + 8, IMAGE_WIDTH + 3])
            # Break the batch into single images so each is augmented independently.
            split[i] = tf.split(split[i], shape)
            for j in range(len(split[i])):
                split[i][j] = tf.reshape(split[i][j], [IMAGE_HEIGHT + 8, IMAGE_WIDTH + 3, 3])
                split[i][j] = tf.random_crop(split[i][j], [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
                split[i][j] = tf.image.random_flip_left_right(split[i][j])
                split[i][j] = tf.image.random_brightness(split[i][j], max_delta=32. / 255.)
                split[i][j] = tf.image.random_saturation(split[i][j], lower=0.5, upper=1.5)
                split[i][j] = tf.image.random_hue(split[i][j], max_delta=0.2)
                split[i][j] = tf.image.random_contrast(split[i][j], lower=0.5, upper=1.5)
                split[i][j] = tf.image.per_image_standardization(split[i][j])
        # Reassemble each side into one [BATCH_SIZE, H, W, 3] tensor.
        return [tf.reshape(tf.concat(split[0], axis=0), [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3]),
                tf.reshape(tf.concat(split[1], axis=0), [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3])]
    def val():
        # Same splitting as train(), but with no random augmentation.
        split = tf.split(images, [1, 1])
        shape = [1 for _ in range(split[0].get_shape()[1])]
        for i in range(len(split)):
            split[i] = tf.reshape(split[i], [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
            split[i] = tf.image.resize_images(split[i], [IMAGE_HEIGHT, IMAGE_WIDTH])
            split[i] = tf.split(split[i], shape)
            for j in range(len(split[i])):
                split[i][j] = tf.reshape(split[i][j], [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
                split[i][j] = tf.image.per_image_standardization(split[i][j])
        return [tf.reshape(tf.concat(split[0], axis=0), [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3]),
                tf.reshape(tf.concat(split[1], axis=0), [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3])]
    # Choose the branch at run time from the is_train placeholder.
    return tf.cond(is_train, train, val)
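To make the shape handling easier to follow, here is a rough NumPy sketch of what the val() branch appears to do. It is not the repository's code: the sizes are small made-up values, preprocess_val_numpy and standardize are hypothetical names, and the resize step is left out on the assumption that the input images already have the target size.

```python
import numpy as np

# Hypothetical small sizes, chosen only to keep the example fast.
BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH = 4, 8, 6


def standardize(image):
    """Approximate tf.image.per_image_standardization: zero mean, unit variance."""
    image = image.astype(np.float32)
    # TensorFlow avoids dividing by zero by using max(stddev, 1 / sqrt(num_elements)).
    adjusted_std = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - image.mean()) / adjusted_std


def preprocess_val_numpy(images):
    """Split the pair axis, then standardize every image on its own."""
    # images has shape [2, BATCH_SIZE, H, W, 3]; axis 0 holds the two sides of each pair.
    sides = [np.stack([standardize(img) for img in side]) for side in images]
    return sides[0], sides[1]


images = np.random.rand(2, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3).astype(np.float32)
images1, images2 = preprocess_val_numpy(images)
print(images1.shape, images2.shape)
```

The train() branch follows the same split and reassemble pattern, with the random crop, flip, and color changes applied to each image before standardization.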
This is all of the code that deals with the images:
if MODE == 'train':
    tarin_num_id = get_num_id(DATA_DIR, 'train')
elif MODE == 'eval':
    val_num_id = get_num_id(DATA_DIR, 'val')
images1, images2 = preprocess(images, is_train, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH)
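I don't understand how it gets hold of the images and sends them to the network.
Thanks for your help.
The full code I am working with comes from here: https://github.com/digitalbrain79/person-reid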
Answer 0 (score: 0)
The answer is that feed_dict is what gets passed in, and it should contain the images you want.
feed_dict = {images: test_images, is_train: False}
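You can load your images into an array such as test_images and pass it in through feed_dict. This saves time, because you can feed different images for training, validation, or testing without changing much code.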
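As a minimal sketch of how such an array could be put together, the snippet below stacks two batches into the [2, BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3] shape that the images placeholder expects. The sizes are made-up small values, and build_pair_batch, sess, and fetches are hypothetical names that do not come from the repository.

```python
import numpy as np

# Hypothetical small sizes; the real values come from the script's own settings.
BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH = 4, 8, 6


def build_pair_batch(first_images, second_images):
    """Stack two image batches into one array of shape [2, B, H, W, 3]."""
    first = np.asarray(first_images, dtype=np.float32)
    second = np.asarray(second_images, dtype=np.float32)
    expected = (BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3)
    if first.shape != expected or second.shape != expected:
        raise ValueError("each batch must have shape %s" % (expected,))
    # Axis 0 indexes the two sides of each pair, matching tf.split(images, [1, 1]).
    return np.stack([first, second], axis=0)


# Placeholder data standing in for real images loaded from disk.
first = np.zeros((BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3))
second = np.ones((BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3))
test_images = build_pair_batch(first, second)

# With the placeholders from the question, the array would then be fed like this
# (sess is assumed to be a tf.Session, fetches the tensors to evaluate):
# feed_dict = {images: test_images, is_train: False}
# result = sess.run(fetches, feed_dict=feed_dict)
```

Setting is_train to False selects the val() branch of preprocess, so the fed images are only standardized and not randomly augmented.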
Thanks to @Chetan Vashisth for pointing me to the feed_dict dictionary.