我正在尝试使用自己的架构训练一个imagenet分类器(我的项目需要预训练的权重)。我已经预处理了ILSVRC2012的图像以及张量流中初始教程中所解释的所有内容,但我无法通过此read_and_decode函数。问题在于image.set_shape()。有谁知道该怎么办?那么set_shape()的目的是什么?
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
# Here comes the error line
image.set_shape([None, None, 3])
label = tf.cast(features['label'], tf.int32)
return image, label
错误日志:
File "./grasp_detection.py", line 49, in read_and_decode
image.set_shape([None, None, 3])
File "/usr/local/lib/python2.7/site-
packages/tensorflow/python/framework/ops.py", line 425, in set_shape
self._shape = self._shape.merge_with(shape)
File "/usr/local/lib/python2.7/site-
packages/tensorflow/python/framework/tensor_shape.py", line 585, in
merge_with
(self, other))
ValueError: Shapes (?,) and (?, ?, 3) are not compatible
编辑:已解决
首先,我在没有set_shape
的情况下对其进行了编程,但我得到了错误ValueError: All shapes must be fully defined
。我知道,在初始教程中,来自tensorflow的所有图像都经过预处理并具有相同的定义形状(对我来说是unkonown)。我认为通过在stackaoverflow中找到形状并使用set_shape
可以解决read_and_decode
的问题。后来我不得不重新塑造图像以适应我的模型。
追求这一点的自然而最好的方法是重新塑造read_and_decode
中的图像,正如评论中所指出的那样。对于所有感兴趣的人来说,工作read_and_decode
看起来像这样:
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
image_shape = tf.stack([IMAGE_HEIGHT, IMAGE_WIDTH, 3])
image = tf.reshape(image, image_shape)
label = tf.cast(features['label'], tf.int32)
return image, label
非常感谢任何建议或批评。
答案 0 :(得分:0)
一个可能的问题:
您是否在image.set_shape()中输入了形状?它应该像
image.set_shape([image_height,image_width,nchannels]) or
image.set_shape([None,None,nchannels])
您可以发布错误日志吗?