输入通道的数量与Keras

时间:2017-08-27 21:33:31

标签: machine-learning neural-network keras deep-learning conv-neural-network

我使用keras构建基于Resnet50的模型,下面的代码如下所示

input_crop = Input(shape=(3, 224, 224))

# extract feature from image crop
resnet = ResNet50(include_top=False, weights='imagenet')
for layer in resnet.layers:  # set resnet as non-trainable
    layer.trainable = False

crop_encoded = resnet(input_crop)  

但是,我收到了错误

  

'ValueError:输入通道的数量与对应的不匹配   过滤器的尺寸,224!= 3'

我该如何解决?

1 个答案:

答案 0 :(得分:4)

由于Theano& amp; amp; amp; amp; amp; amp; amp; amp; amp; and TrasorFlow backends代表Keras。在您的情况下,图像显然是channels_first格式(Theano),而大多数情况下您使用的是TensorFlow后端,需要channels_last格式。

Keras的MNIST CNN example提供了一种很好的方法来让你的代码免受这些问题的影响,即同时为Theano& TensorFlow后端 - 这是对您的数据的修改:

from keras import backend as K

img_rows, img_cols = 224, 224

if K.image_data_format() == 'channels_first':
    input_crop = input_crop.reshape(input_crop.shape[0], 3, img_rows, img_cols)
    input_shape = (3, img_rows, img_cols)
else:
    input_crop = input_crop.reshape(input_crop.shape[0], img_rows, img_cols, 3)
    input_shape = (img_rows, img_cols, 3)

input_crop = Input(shape=input_shape)