我有一个处理过的图像张量(例如25x25x1张量),我将其存储在TFRecord文件中。我正在使用tf.Example方法进行序列化并准备TFRecord文件。
当我读回它时,张量显示为(644 ,?),并且无法将其重塑为25x25x1。我想知道这些额外的值如何进入张量。
如何解决此问题?请帮忙。
我在这里遵循Chip Huyen提供的代码示例-https://github.com/chiphuyen/stanford-tensorflow-tutorials/blob/master/2017/examples/09_tfrecord_example.py
这是我用于写入文件的代码。
channels = 1
resize_height = 25
resize_width = 25
target_file = ./some_JPEG_file.jpeg
shape = np.array([resize_height, resize_width,channels],np.int32).tobytes()
def return_image_resized(filename,channels):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string)
image_resized = tf.image.resize_images(image_decoded, [resize_height, resize_width],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
if channels ==1:
image_resized = tf.image.rgb_to_grayscale(image_resized)
return image_resized
def write_to_tfrecord(writer, target, shape, tfrecord_file):
example = tf.train.Example(features=tf.train.Features(feature={
'target': _bytes_feature(target),
'shape': _bytes_feature(shape)
}))
writer.write(example.SerializeToString())
target = tf.serialize_tensor(return_image_resized(target_file,channels))
writer = tf.python_io.TFRecordWriter(output_file)
write_to_tfrecord(writer, target, shape, output_file)
writer.close()
但是,当我尝试使用以下代码读取文件时,出现错误消息“ InvalidArgumentError:要重塑的输入是具有644个值的张量,但请求的形状为625”
def read_from_tfrecord(filenames):
tfrecord_file_queue = tf.train.string_input_producer(filenames, name='queue')
reader = tf.TFRecordReader()
_, tfrecord_serialized = reader.read(tfrecord_file_queue)
tfrecord_features = tf.parse_single_example(tfrecord_serialized,
features={
'target': tf.FixedLenFeature([], tf.string),
'shape': tf.FixedLenFeature([], tf.string),
}, name='features')
target = tf.decode_raw(tfrecord_features['target'], tf.uint8)
shape = tf.decode_raw(tfrecord_features['shape'], tf.int32)
#Error is coming from the below line of code when I try to reshape it
target = tf.reshape(target, shape)
return target, shape
def read_tfrecord(tfrecord_file):
target, shape = read_from_tfrecord([tfrecord_file])
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
target, shape = sess.run([target, shape])
print(target.shape)
coord.request_stop()
coord.join(threads)
有人可以帮助我了解我要去哪里哪里吗?张量如何获得额外的值?