I am trying to read image data from a TFRecord file. The code I use to write the encoded images into the TFRecord file is as follows:
```python
def create_tf_example(example):
    height = 256
    width = 256
    depth = 3
    filename = example['Title']
    filename = filename.encode()
    path = example['path']
    path = path.replace('.json', '')
    example['path'] = path
    with open(example['path'], 'rb') as f:
        # read the encoded image file's bytes as stored on disk
        encoded_image_data = f.read()
    # ... (remainder of the function not shown)
```
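Note that `f.read()` returns the file's bytes exactly as stored on disk, i.e. in its compressed image format, not a 256×256×3 array of pixels. A minimal stdlib-only sketch, assuming a PNG input, makes the size difference visible; `make_png` and `read_file_bytes` are hypothetical helpers used only to fabricate and read a test image:

```python
import os
import struct
import tempfile
import zlib

HEIGHT, WIDTH, DEPTH = 256, 256, 3

def make_png(width, height):
    """Build a minimal 8-bit RGB PNG of one solid colour (hypothetical test helper)."""
    def chunk(tag, data):
        body = tag + data
        crc = zlib.crc32(body) & 0xFFFFFFFF
        return struct.pack(">I", len(data)) + body + struct.pack(">I", crc)
    # Each scanline is a filter-type byte (0 = none) followed by RGB triples.
    row = b"\x00" + b"\x80\x40\x20" * width
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)  # 8-bit, colour type 2 = RGB
    return (b"\x89PNG\r\n\x1a\n"
            + chunk(b"IHDR", ihdr)
            + chunk(b"IDAT", zlib.compress(row * height))
            + chunk(b"IEND", b""))

def read_file_bytes(path):
    # Same pattern as the writer: returns the encoded file contents, not pixels.
    with open(path, "rb") as f:
        return f.read()

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "example.png")
    with open(path, "wb") as f:
        f.write(make_png(WIDTH, HEIGHT))
    encoded = read_file_bytes(path)

raw_pixel_count = HEIGHT * WIDTH * DEPTH  # what tf.reshape expects from decode_raw
print(len(encoded), raw_pixel_count)
```

The encoded byte count depends on the image content and compression, so it will generally not equal `height * width * depth`.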
When I try to read the encoded image back:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.TFRecordReader, tf.decode_raw)

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/depth': tf.FixedLenFeature([], tf.int64),
            'image/encoded': tf.FixedLenFeature([], tf.string),
        })
    # decode_raw reinterprets the stored bytes directly as uint8 values
    image = tf.decode_raw(features['image/encoded'], tf.uint8)
    height = tf.cast(features['image/height'], tf.int64)
    width = tf.cast(features['image/width'], tf.int64)
    depth = tf.cast(features['image/depth'], tf.int64)
    image_shape = tf.stack([height, width, depth])
    image = tf.reshape(image, image_shape)
    images = image
    return images
```
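For comparison, a sketch of decoding with an image decoder instead of `decode_raw`, assuming the stored bytes are a PNG or JPEG file. It uses the TF 2.x eager `tf.io.*` API rather than the queue-based TF1 API above, and round-trips a synthetic gradient image through a `tf.train.Example` with the same feature keys; the image and the helpers `_int64`/`_bytes` are illustrative assumptions:

```python
import numpy as np
import tensorflow as tf

HEIGHT, WIDTH, DEPTH = 256, 256, 3

# Synthetic RGB gradient, PNG-encoded the way an image file on disk would be.
rows, cols = np.mgrid[0:HEIGHT, 0:WIDTH]
pixels = np.stack([rows, cols, (rows + cols) // 2], axis=-1).astype(np.uint8)
encoded = tf.io.encode_png(pixels).numpy()  # Python bytes

def _int64(v):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[v]))

def _bytes(v):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))

example = tf.train.Example(features=tf.train.Features(feature={
    "image/height": _int64(HEIGHT),
    "image/width": _int64(WIDTH),
    "image/depth": _int64(DEPTH),
    "image/encoded": _bytes(encoded),
}))
serialized = example.SerializeToString()

features = tf.io.parse_single_example(serialized, {
    "image/height": tf.io.FixedLenFeature([], tf.int64),
    "image/width": tf.io.FixedLenFeature([], tf.int64),
    "image/depth": tf.io.FixedLenFeature([], tf.int64),
    "image/encoded": tf.io.FixedLenFeature([], tf.string),
})

# decode_image decompresses PNG/JPEG into a uint8 [height, width, channels] tensor.
image = tf.io.decode_image(features["image/encoded"], channels=DEPTH)
shape = tf.stack([features["image/height"], features["image/width"], features["image/depth"]])
image = tf.reshape(image, shape)
print(len(encoded), image.shape)
```

In the TF1 code, the corresponding change would be replacing the `tf.decode_raw(...)` line with `tf.image.decode_jpeg(features['image/encoded'], channels=3)` (or `tf.image.decode_png` for PNG files); which decoder fits depends on the actual file format, which the post does not state.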
It gives me this error:
```
InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 11566 values, but the requested shape has 196608
[[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DecodeRaw, stack)]]
```
It seems to me that when I encode the image, not all 3 channels get encoded. Please help.
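The two numbers in the error message can be checked by hand. The requested shape is exactly the uncompressed size of one 256×256 RGB image, while 11566 is smaller than even a single 256×256 channel, so a missing channel alone would not account for it. A count this small is consistent with `decode_raw` returning the bytes of a compressed file, though that is an inference the error message itself does not confirm:

```latex
\begin{aligned}
256 \times 256 \times 3 &= 196608 \quad \text{(requested shape)}\\
256 \times 256 \times 1 &= 65536 \quad \text{(one channel)}\\
\frac{11566}{65536} &\approx 0.18
\end{aligned}
```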