我需要在keras
中构建一个初始模块,并以此来训练cifar100
数据集。
使用以下代码:
from keras.datasets import cifar100
(X_train, y_train), (X_test, y_test) = cifar100.load_data()
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train = X_train / 255.0
X_test = X_test / 255.0
from keras.utils import np_utils
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
from keras.layers import Input
input_img = Input(shape = (32, 32,3))
import keras
from keras.layers import Conv2D, MaxPooling2D
tower_1 = Conv2D(64, (1,1), padding='same', activation='relu')(input_img)
tower_1 = Conv2D(64, (3,3), padding='same', activation='relu')(tower_1)
tower_2 = Conv2D(64, (1,1), padding='same', activation='relu')(input_img)
tower_2 = Conv2D(64, (5,5), padding='same', activation='relu')(tower_2)
tower_3 = MaxPooling2D((3,3), strides=(1,1), padding='same')(input_img)
tower_3 = Conv2D(64, (1,1), padding='same', activation='relu')(tower_3)
output = keras.layers.concatenate([tower_1, tower_2, tower_3], axis = 3)
from keras.layers import Flatten, Dense
output = Flatten()(output)
out = Dense(100, activation='softmax')(output)
from keras.models import Model
model = Model(inputs = input_img, outputs = out)
#print model.summary()
from keras.optimizers import SGD
epochs = 30
lrate = 0.01
decay = lrate/epochs
sgd = SGD(lr=lrate, momentum=0.9, decay=decay, nesterov=False)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=epochs, batch_size=32)
# scores = model.evaluate(X_test, y_test, verbose=0)
#print("Accuracy: %.2f%%" % (scores[1]*100))
我收到以下值错误
ValueError: Error when checking target: expected dense_1 to have shape (10,) but got array with shape (100,)
/opt/conda/lib/python3.6/site-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
136 ': expected ' + names[i] + ' to have shape ' +
137 str(shape) + ' but got array with shape ' +
--> 138 str(data_shape))
139 return data
140
使用kaggle内核脚本。 有人能告诉我什么地方出问题了吗?