TF记录显示的张量形状与序列化并存储在其中的张量不同

时间:2019-04-03 10:32:30

标签: python tensorflow

我有一个处理过的图像张量(例如25x25x1张量),我将其存储在TFRecord文件中。我正在使用tf.Example方法进行序列化并准备TFRecord文件。

当我读回它时,张量显示为(644 ,?),并且无法将其重塑为25x25x1。我想知道这些额外的值如何进入张量。

如何解决此问题?请帮忙。

我在这里遵循Chip Huyen提供的代码示例-https://github.com/chiphuyen/stanford-tensorflow-tutorials/blob/master/2017/examples/09_tfrecord_example.py

这是我用于写入文件的代码。

channels = 1
resize_height = 25
resize_width = 25

target_file = ./some_JPEG_file.jpeg
shape = np.array([resize_height, resize_width,channels],np.int32).tobytes()

def return_image_resized(filename,channels):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string)
  image_resized = tf.image.resize_images(image_decoded, [resize_height, resize_width],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  if channels ==1:
    image_resized = tf.image.rgb_to_grayscale(image_resized)
  return image_resized

def write_to_tfrecord(writer, target, shape, tfrecord_file):

    example = tf.train.Example(features=tf.train.Features(feature={
                'target': _bytes_feature(target),
                'shape': _bytes_feature(shape)
                }))
    writer.write(example.SerializeToString())

target = tf.serialize_tensor(return_image_resized(target_file,channels))
writer = tf.python_io.TFRecordWriter(output_file)
write_to_tfrecord(writer, target, shape, output_file)
writer.close()

但是,当我尝试使用以下代码读取文件时,出现错误消息“ InvalidArgumentError:要重塑的输入是具有644个值的张量,但请求的形状为625”

def read_from_tfrecord(filenames):
    tfrecord_file_queue = tf.train.string_input_producer(filenames, name='queue')
    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(tfrecord_file_queue)


    tfrecord_features = tf.parse_single_example(tfrecord_serialized,
                        features={
                            'target': tf.FixedLenFeature([], tf.string),
                            'shape': tf.FixedLenFeature([], tf.string),                
                        }, name='features')


    target = tf.decode_raw(tfrecord_features['target'], tf.uint8)
    shape = tf.decode_raw(tfrecord_features['shape'], tf.int32)

    #Error is coming from the below line of code when I try to reshape it
    target = tf.reshape(target, shape)

    return target, shape

def read_tfrecord(tfrecord_file):
    target, shape = read_from_tfrecord([tfrecord_file])

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        target, shape = sess.run([target, shape])
        print(target.shape)
        coord.request_stop()
        coord.join(threads)

有人可以帮助我了解我要去哪里哪里吗?张量如何获得额外的值?

0 个答案:

没有答案