How do I feed tensors of unknown dimensions to a network?

Time: 2018-08-19 09:19:03

Tags: python tensorflow keras

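Thanks for reading this long post. I am using Keras with the TensorFlow backend, and I am trying to train a U-Net for an object detection task. My images come in different sizes, and resizing them would not preserve their aspect ratios. Since U-Net is a fully convolutional network, it should be possible to train it with inputs of arbitrary size. However, I have not been able to feed such data to the network: all the material I have found on data feeding uses fixed-size inputs. Here is my code for making the tfrecord.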

import os

import tensorflow as tf


def _convert_to_example(filename, image_buffer, mask, height, width):
  """Build an Example proto for an example.
  Args:
    filename: string, path to an image file, e.g., '/path/to/example.JPG'
    image_buffer: string, JPEG encoding of RGB image
    mask: string, encoded mask image (decoded as JPEG when read back)
    height: integer, image height in pixels
    width: integer, image width in pixels
  Returns:
    Example proto
  """

  colorspace = 'RGB'
  channels = 3
  image_format = 'JPEG'

  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': _int64_feature(height),
      'image/width': _int64_feature(width),
      'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)),
      'image/channels': _int64_feature(channels),
      'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
      'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
      'mask/encoded': _bytes_feature(tf.compat.as_bytes(mask)),
      'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))}))
  return example

I use this code to read the images back from the file:

def getImage(filename):
    # convert filenames to a queue for an input pipeline.
    filenameQ = tf.train.string_input_producer([filename], num_epochs=None)

    # object to read records
    recordReader = tf.TFRecordReader()

    # read the full set of features for a single example
    key, fullExample = recordReader.read(filenameQ)

    # parse the full example into its component features.
    features = tf.parse_single_example(
        fullExample,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/colorspace': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/channels': tf.FixedLenFeature([], tf.int64),
            'image/format': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'mask/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='')
        })

    # now we are going to manipulate the label and image features
    label = features['mask/encoded']
    image_buffer = features['image/encoded']
    width = tf.cast(features['image/width'], tf.int32)
    height = tf.cast(features['image/height'], tf.int32)
    with tf.name_scope('decode_jpeg', [image_buffer], None):
        # decode
        image = tf.image.decode_jpeg(image_buffer, channels=3)
        label = tf.image.decode_jpeg(label, channels=1)

        # and convert to single precision data type
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        label = tf.image.convert_image_dtype(label, dtype=tf.float32)

    # reshape using the height and width stored in the record; both are
    # only known at run time, so the static shape is (?, ?, channels)
    image = tf.reshape(image, [height, width, 3])
    label = tf.reshape(label, [height, width, 1])
    return label, image

When I print the image tensor, its shape is shown as

Tensor("Reshape:0", shape=(?, ?, 3), dtype=float32)

From this, I assume it shows ? because the tensors have different sizes. After that, when I run

label, image=getImage('train-00000-of-00001')
tlabel, timage=getImage('test-00000-of-00001')
x_train_batch, y_train_batch = tf.train.shuffle_batch(
    tensors=[image,label],
    batch_size=batch_size,
    capacity=capacity,
    min_after_dequeue=min_after_dequeue,
    enqueue_many=enqueue_many,
    num_threads=8)

I get the error

  

ValueError: All shapes must be fully defined: [TensorShape([Dimension(None), Dimension(3)]), TensorShape([Dimension(None), Dimension(1)])
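
A note on where this message likely comes from: tf.train.shuffle_batch needs every tensor it enqueues to have a fully defined static shape, while the height and width here are only known at run time. One common workaround is to pad every image in a batch to the largest height and width found in that batch. The sketch below shows the idea with NumPy rather than TensorFlow; the function name pad_batch and the choice of zero padding are assumptions made for illustration and are not part of the original code.

```python
import numpy as np


def pad_batch(images, pad_value=0.0):
    """Pad a list of H x W x C arrays to a common size and stack them.

    Hypothetical helper: each image is placed in the top-left corner and
    the remaining area is filled with pad_value.
    """
    max_h = max(img.shape[0] for img in images)
    max_w = max(img.shape[1] for img in images)
    channels = images[0].shape[2]

    # Allocate the batch once, already filled with the padding value.
    batch = np.full((len(images), max_h, max_w, channels), pad_value,
                    dtype=np.float32)
    for i, img in enumerate(images):
        h, w = img.shape[:2]
        batch[i, :h, :w, :] = img
    return batch


# Two images with different heights and widths.
images = [np.ones((4, 6, 3), dtype=np.float32),
          np.ones((5, 3, 3), dtype=np.float32)]
batch = pad_batch(images)
print(batch.shape)  # (2, 5, 6, 3)
```

Within TensorFlow, tf.train.batch accepts dynamic_pad=True and tf.data.Dataset.padded_batch pads per batch in a similar way, whereas tf.train.shuffle_batch has no such option. The masks would need the same padding so that they stay aligned with the images.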

Question:
Am I preparing the data incorrectly, or does TensorFlow not allow this kind of data to be fed to the network?
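
One point that bears on the "arbitrary size" assumption: in a typical U-Net with 'same' padding and several 2x2 pooling steps, the skip connections only line up when the input height and width are divisible by 2 raised to the number of pooling steps. The sketch below rounds a size up to the next valid value. The depth of 4 and the function names are assumptions for illustration, since the post does not show the network definition.

```python
def round_up_to_multiple(size, multiple):
    """Return the smallest multiple of `multiple` that is >= size."""
    return ((size + multiple - 1) // multiple) * multiple


def unet_input_size(height, width, depth=4):
    """Round height and width up so they survive `depth` 2x2 poolings.

    depth=4 is an assumed number of pooling steps, not taken from the post.
    """
    multiple = 2 ** depth
    return (round_up_to_multiple(height, multiple),
            round_up_to_multiple(width, multiple))


print(unet_input_size(375, 500))  # (384, 512)
print(unet_input_size(256, 256))  # (256, 256)
```

Padding to such a size keeps the aspect ratio of the original content intact, which plain resizing would not.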

0 answers:

No answers