我是TensorFlow的新手,我正在尝试建立一个使用tf.data
来摄取图像的管道。我正在使用的注释由像素注释,并存储在注释图像中。我已经创建了TFRecordData
个文件,并将它们存储在我的主目录中。我正在尝试读取它们并打印出像素值,以便确保我知道它们存储的格式。有人可以解释如何打印出这些TFRecordData
文件的内容吗?
我收到以下错误:
TypeError:提取参数具有 无效类型,必须为 字符串或张量。 (无法将PrefetchDataset转换为张量或 操作。)
下面是我的代码:
import tensorflow as tf
import cv2
import sys
import numpy as np
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def parser(record):
keys_to_features = {
"image_raw": tf.FixedLenFeature([], tf.string),
"anno_raw": tf.FixedLenFeature([], tf.string)
}
parsed = tf.parse_single_example(record, keys_to_features)
image = tf.decode_raw(parsed["image_raw"], tf.uint8)
image = tf.cast(image, tf.float32)
anno = tf.decode_raw(parsed["anno_raw"], tf.uint8)
anno = tf.cast(anno, tf.float32)
return {'image': image,'anno': anno}
def input_fn(filenames):
dataset = tf.data.TFRecordDataset(filenames=filenames, num_parallel_reads=40)
dataset = dataset.apply(
tf.contrib.data.shuffle_and_repeat(1024, 1)
)
dataset = dataset.apply(
tf.contrib.data.map_and_batch(parser, 32)
)
dataset = dataset.prefetch(buffer_size=2)
return dataset
def train_input_fn():
return input_fn(filenames=["train.tfrecords", "test.tfrecords"])
data = train_input_fn()
print(sess.run(data))