I'm trying to read image files with TensorFlow and derive each image's label from its file path, like this:
import tensorflow as tf

filename_queue = tf.train.string_input_producer(
    tf.matching_files(
        tf.constant(["./positive_images/*.jpg",
                     "./negative_images/*.jpg"])))

image_reader = tf.WholeFileReader()
file_name, image_file = image_reader.read(filename_queue)
label = tf.map_fn(lambda x: "positive" in x, file_name)
image = tf.image.decode_jpeg(image_file, channels=1)

with tf.Session() as sess:
    # Required to get the filename matching to run.
    tf.global_variables_initializer()

    # Coordinate the loading of image files.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    filename1 = sess.run([label])
    print(filename1)
    image_tensor = sess.run([image])
    print(image_tensor)

    coord.request_stop()
    coord.join(threads)
But I get this error:
  in map_fn
    raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
ValueError: elems must be a 1+ dimensional Tensor, not a scalar
Reading the images seems to work, but the file name isn't parsed correctly. What am I doing wrong?
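A likely explanation: each call to `image_reader.read` yields a single file name as a 0-d (scalar) string tensor, while `tf.map_fn` unpacks `elems` along dimension 0, so it rejects scalars with exactly this error. Separately, Python's `in` operator inside the lambda would not perform a substring test on a symbolic tensor. A string op such as `tf.strings.regex_full_match` already works on a tensor of any shape, including a scalar, so no mapping is needed. The sketch below assumes TensorFlow 2.x, where queue runners are deprecated in favor of `tf.data`; the two paths are hypothetical stand-ins for the files matched by the glob patterns, and the `".*positive.*"` pattern is an illustrative choice.

```python
import tensorflow as tf

# Hypothetical example paths standing in for the output of tf.io.matching_files.
paths = tf.constant([
    "./positive_images/a.jpg",
    "./negative_images/b.jpg",
])

def label_from_path(path):
    # Inside Dataset.map, `path` is a scalar (0-d) string tensor.
    # regex_full_match is element-wise, so it accepts a scalar directly
    # and returns a scalar bool tensor; no tf.map_fn is required.
    return tf.strings.regex_full_match(path, ".*positive.*")

# One element per file path; each element is mapped to a bool label.
labels = tf.data.Dataset.from_tensor_slices(paths).map(label_from_path)
label_list = [bool(x) for x in labels.as_numpy_iterator()]
print(label_list)  # [True, False]
```

In a real pipeline the same `map` function could also call `tf.io.read_file` and `tf.image.decode_jpeg` on `path` and return an `(image, label)` pair.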