set_shape使用tensorflow中的TFRecordReader在read_and_decode中引发问题

时间:2017-05-22 20:51:19

标签: tensorflow

我正在尝试使用自己的架构训练一个imagenet分类器(我的项目需要预训练的权重)。我已经预处理了ILSVRC2012的图像以及张量流中初始教程中所解释的所有内容,但我无法通过此read_and_decode函数。问题在于image.set_shape()。有谁知道该怎么办?那么set_shape()的目的是什么?

def read_and_decode(filename_queue):
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    })
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

  # Here comes the error line
  image.set_shape([None, None, 3])

  label = tf.cast(features['label'], tf.int32)

  return image, label

错误日志:

  File "./grasp_detection.py", line 49, in read_and_decode
  image.set_shape([None, None, 3])
  File "/usr/local/lib/python2.7/site-
  packages/tensorflow/python/framework/ops.py", line 425, in set_shape
  self._shape = self._shape.merge_with(shape)
  File "/usr/local/lib/python2.7/site-
  packages/tensorflow/python/framework/tensor_shape.py", line 585, in 
  merge_with
    (self, other))
  ValueError: Shapes (?,) and (?, ?, 3) are not compatible

编辑:已解决 首先,我在没有set_shape的情况下对其进行了编程,但我得到了错误ValueError: All shapes must be fully defined。我知道,在初始教程中,来自tensorflow的所有图像都经过预处理并具有相同的定义形状(对我来说是unkonown)。我认为通过在stackaoverflow中找到形状并使用set_shape可以解决read_and_decode的问题。后来我不得不重新塑造图像以适应我的模型。

追求这一点的自然而最好的方法是重新塑造read_and_decode中的图像,正如评论中所指出的那样。对于所有感兴趣的人来说,工作read_and_decode看起来像这样:

def read_and_decode(filename_queue):
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64)
      })
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
  image_shape = tf.stack([IMAGE_HEIGHT, IMAGE_WIDTH, 3])
  image = tf.reshape(image, image_shape)
  label = tf.cast(features['label'], tf.int32)
  return image, label

非常感谢任何建议或批评。

1 个答案:

答案 0 :(得分:0)

一个可能的问题:

您是否在image.set_shape()中输入了形状?它应该像

image.set_shape([image_height,image_width,nchannels]) or 
image.set_shape([None,None,nchannels])

您可以发布错误日志吗?