我正在学习深度学习。我尝试转移学习,因为我使用vgg16模型。但是,我面对error: Shapes (None, 1) and (None, 2) are incompatible
。我不知道为什么不兼容。帮我。
对不起,我的英语说得不好。但我想知道为什么会出错。
我的代码。
我已经知道,如果我使用sigmod
(激活)可以对数据进行分类。但是我想对三个或三个以上(狗,猫,马,老鼠.....)进行分类,所以我使用softmax
。帮帮我。
ValueError: Shapes (None, 1) and (None, 2) are incompatible
问题出在哪里?
def save_bottlebeck_features():
datagen = ImageDataGenerator(rescale=1. / 255)
# build the VGG16 network
model = applications.VGG16(include_top=False, weights='imagenet')
generator = datagen.flow_from_directory(
train_data_dir,
target_size=(150, 150),
batch_size=batch_size,
class_mode='categorical',
shuffle=False)
bottleneck_features_train = model.predict_generator(
generator)
np.save('bottleneck_features_train.npy',bottleneck_features_train)
generator = datagen.flow_from_directory(
validation_data_dir,
target_size=(150, 150),
batch_size=batch_size,
class_mode='categorical',
shuffle=False)
bottleneck_features_validation = model.predict_generator(
generator)
np.save('bottleneck_features_validation.npy',bottleneck_features_validation)
def train_top_model():
train_data = np.load('bottleneck_features_train.npy')
train_labels = np.array(
[0] * 682 + [1] * 403) # dog: 682 cat : 403
validation_data = np.load('bottleneck_features_validation.npy')
validation_labels = np.array(
[0] * 63 + [1] * 70 )
model = Sequential()
model.add(Flatten(input_shape=train_data.shape[1:]))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(2, activation='softmax'))
model.summary()
model.compile(optimizer='adam',
loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_data, train_labels,
epochs=epochs,
steps_per_epoch=1000 // batch_size,
validation_data=(validation_data, validation_labels))
model.save_weights(top_model_weights_path)
答案 0 :(得分:0)
您遇到的问题是您以[0]和[1]的形式创建了基本事实。
但是,您使用的损失函数为categorical_crossentropy
,在这种情况下,期望输入目标是二维数组(n个类=> n维),而不是一维数组。 / p>
实际上,如果图片属于猫,那么您的网络期望[0,1]作为基本事实,如果图片是狗,则期望[1,0,1]作为事实。
但是,您只输入[0]和[1]而不是[0,1]或[1,0]。
您的问题的解决方案是:
tf.keras.utils.to_categorical()
或keras.utils.to_categorical()
。sparse_categorical_crossentropy
作为损失函数,该函数可让您使用纯整数,例如0,1,2,3作为标签。