当我使用tf.data.TFRecordDataset读取许多tfrecords时。我从tfrecord中读取了标签和图像。我使用张量板对图像进行摘要,然后将标签写入日志文件。但是当我查看日志文件和张量板时。标签和图像不对应。如下,我的代码读取tfrecrods。
def parser(record):
features = tf.parse_single_example(record,
features={
'label': tf.FixedLenFeature([], tf.int64),
'image': tf.FixedLenFeature([], tf.string)
}) # 取出包含image和label的feature对象
recode_image = tf.decode_raw(features['image'], tf.uint8)
real_image = tf.reshape(recode_image, shape=[38, 38, 1])
lable = tf.cast(features['label'], tf.int64)
return real_image,lable
def read_data(file_path):
min_after_dequeue = 100
batch_size = 3
data=tf.data.TFRecordDataset(file_path)
dataset=data.map(parser).
shuffle(buffer_size=min_after_dequeue).
batch(batch_size=batch_size)
dataset=dataset.repeat()
dataset.prefetch(100)
iterator = dataset.make_one_shot_iterator()
image_batch, lable_batch = iterator.get_next()
image_batch=input_float(image_batch)
return image_batch,lable_batch
在主线程代码中使用read_data是:
file_list=glob.glob("./tfcode/training_image/*.tfrecord")
file_list = list(
map(lambda image: image.replace('\\', '/'), file_list))
image_batch, lable_batch= read_data(file_list)
tf.summary.image(tensor=image_batch,name="image")
input_lable = sess.run(lable_batch)
logger.info(input_lable)
以下是我在张量板上的外观: enter image description here
以下是我在日志文件中看到的内容: enter image description here
张量板摘要为[1,3,3],但日志文件为[3,3,3]的图像标签。
我该如何处理。
答案 0 :(得分:0)
为什么不使用https://www.tensorflow.org/api_docs/python/tf/image/decode_jpeg而不是tf.decode_raw
?
def parser(record):
features = tf.parse_single_example(record,
features={
'label': tf.FixedLenFeature([], tf.int64),
'image': tf.FixedLenFeature([], tf.string)
})
recode_image = tf.image.decode_jpeg(features['image'], channels=1)
real_image = tf.reshape(recode_image, shape=[38, 38])
lable = tf.cast(features['label'], tf.int64)
return real_image,lable
def read_data(file_path):
min_after_dequeue = 100
batch_size = 3
data=tf.data.TFRecordDataset(file_path)
dataset=data.map(parser).
shuffle(buffer_size=min_after_dequeue).
batch(batch_size=batch_size)
dataset=dataset.repeat()
dataset.prefetch(100)
iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()
image_batch=input_float(image_batch) # I'm assuming you are rescalling the image to [0,1]
return image_batch,label_batch