TFSlim ValueError:无法压缩(squeeze)dim[1],期望该维度为 1,实际得到 3,出错节点为 'vgg_16/fc8/squeezed'(op: 'Squeeze'),输入形状:[3,3,3,2]

时间:2017-04-05 21:13:58

标签: python tensorflow

我正在尝试在类别数不同(2 类)的标签上微调 TensorFlow Slim 的 VGG16 网络(fc8 层除外)。执行时收到如下错误。

错误

logits, _ = vgg.vgg_16(images, num_classes=NUM_CLASSES, is_training=True)
/models/slim/nets/vgg.py", line 178, in vgg_16
net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 2273, in squeeze
--- STACK TRACE OMITTED -----
/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl
raise ValueError(err.message)
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 3 for 'vgg_16/fc8/squeezed' (op: 'Squeeze') with input shapes: [3,3,3,2].

代码

 # Number of examples per batch and number of target classes for fine-tuning.
 BATCH_SIZE = 3
 NUM_CLASSES = 2


 def load_batch():
    """Build the queue-based input pipeline and return one batch of tensors.

    File paths and labels come from ``read_label_file(train_labels_file)``.
    Each JPEG is decoded, declared to have static shape 387x408x3, resized
    to 224x224, and then grouped into batches of ``BATCH_SIZE``.

    Returns:
        A pair ``(image_batch, label_batch)`` as produced by ``tf.train.batch``.
    """
    paths, raw_labels = read_label_file(train_labels_file)
    path_tensor = ops.convert_to_tensor(paths, dtype=dtypes.string)
    label_tensor = ops.convert_to_tensor(raw_labels, dtype=dtypes.int32)

    # One (path, label) pair is dequeued at a time, in file order (no shuffling).
    path_item, label_item = tf.train.slice_input_producer(
        [path_tensor, label_tensor], shuffle=False)

    decoded = tf.image.decode_jpeg(tf.read_file(path_item), channels=NUM_CHANNELS)
    # NOTE(review): assumes every input image is 387x408 RGB -- confirm against the data.
    decoded.set_shape([387, 408, 3])

    target_size = tf.constant([224, 224], dtype=tf.int32)
    resized = tf.image.resize_images(decoded, target_size)

    image_batch, label_batch = tf.train.batch(
        [resized, label_item], batch_size=BATCH_SIZE, num_threads=1)
    return image_batch, label_batch

 # Make a fresh graph the default graph for everything built below.
 with tf.Graph().as_default():

    tf.logging.set_verbosity(tf.logging.INFO)
    # Batched image and label tensors from the input pipeline (load_batch).
    images,labels = load_batch()
    # Build VGG-16 under its default arg scope with a NUM_CLASSES-way output.
    with slim.arg_scope(vgg.vgg_arg_scope()):
        # This is the call reported in the traceback: the squeeze of fc8 over
        # dims [1, 2] fails because those dims are 3 instead of 1.
        logits, _ = vgg.vgg_16(images, num_classes=NUM_CLASSES, is_training=True)
    # The remainder of the script was elided in the original post.
    .... 

1 个答案:

答案 0 :(得分:0)

您可以尝试直接构造一个形状已知的批次(而不经过输入管道),以确认问题是否出在输入数据的形状上:

# Debugging variant: feed VGG-16 a synthetic batch whose static shape is fully
# known, bypassing the file-based input pipeline.
with tf.Graph().as_default():

    tf.logging.set_verbosity(tf.logging.INFO)
    # Random images with a fixed [batch, 224, 224, 3] shape.
    images = tf.random_uniform([BATCH_SIZE, 224, 224, 3])
    # Random integer class ids in [0, NUM_CLASSES); tf.random_uniform takes the
    # upper bound as `maxval`, and it must be given explicitly for integer dtypes.
    labels = tf.random_uniform([BATCH_SIZE], maxval=NUM_CLASSES, dtype=tf.int32)
    with slim.arg_scope(vgg.vgg_arg_scope()):
        logits, _ = vgg.vgg_16(images, num_classes=NUM_CLASSES, is_training=True)

您还可以在将 images 和 labels 张量传入 vgg.vgg_16 之前,先打印并检查它们的形状(例如 print(images.get_shape())),确认图像批次的形状为 [BATCH_SIZE, 224, 224, 3]。