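I really like using Keras for LSTMs, and I am trying to follow the solution from "Keras LSTM multiclass classification" to do multi-class classification with an LSTM. I have 768-dimensional feature vectors belonging to 10 classes, and I want to classify them with an LSTM. This is what I tried: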
def do_experiment(train_file, validation_file, test_file, experiment_number, optimizer_name):
    def scheduler(epoch):
        if epoch % 4 == 0 and epoch:
            K.set_value(model.optimizer.lr, K.get_value(model.optimizer.lr)*0.9)
            print(K.get_value(model.optimizer.lr))
        return K.get_value(model.optimizer.lr)

    change_lr = LearningRateScheduler(scheduler)
    early_stopper = EarlyStopping(min_delta=0.001, patience=15)
    csv_logger = CSVLogger('lstm.csv')
    weights_file = "trained_model/" + str(experiment_number) + "-weights.h5"
    model_checkpoint = ModelCheckpoint(weights_file, monitor="val_loss", save_best_only=True,
                                       save_weights_only=True, mode='auto')

    x_train, y_train, groundtruth_train = du.loaddata(train_file, experiment_number)
    x_validation, y_validation, groundtruth_validation = du.loaddata(validation_file, experiment_number)

    batch_size = 32
    nb_classes = 10
    nb_epoch = 100

    model = Sequential()
    model.add(Embedding(5000, 32, input_length=768))
    model.add(LSTM(100, dropout=0.2, recurrent_dropout=0.2))
    model.add(Dense(10, activation='softmax'))
    model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer_name, metrics=['accuracy'])
    model.fit(x_train, y_train, batch_size=batch_size, epochs=nb_epoch,
              validation_data=(x_validation, y_validation), shuffle=True,
              callbacks=[change_lr, early_stopper, csv_logger, model_checkpoint])
But whenever I run this code, I get the following error:
  File "/usr/lib64/python2.7/site-packages/keras/models.py", line 960, in fit
    validation_steps=validation_steps)
  File "/usr/lib64/python2.7/site-packages/keras/engine/training.py", line 1581, in fit
    batch_size=batch_size)
  File "/usr/lib64/python2.7/site-packages/keras/engine/training.py", line 1418, in _standardize_user_data
    exception_prefix='target')
  File "/usr/lib64/python2.7/site-packages/keras/engine/training.py", line 153, in _standardize_input_data
    str(array.shape))
ValueError: Error when checking target: expected dense_1 to have shape (None, 1) but got array with shape (61171, 10)
I am sure I am doing something very silly here, but I cannot spot it. How should I change this code to classify 768-dimensional vectors?
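One possible reading of the traceback: the target array has shape (61171, 10), which suggests that y_train is one-hot encoded, while the sparse_categorical_crossentropy loss expects one integer class id per sample. The sketch below uses only NumPy to show the conversion from one-hot rows to integer ids. The helper name one_hot_to_sparse and the small stand-in array are hypothetical, and the assumption that du.loaddata returns one-hot labels is inferred from the error message, not confirmed by the code above.

```python
import numpy as np


def one_hot_to_sparse(y_one_hot):
    """Convert one-hot rows (n_samples, n_classes) to integer class ids (n_samples,).

    Hypothetical helper: assumes each row has exactly one 1 marking the class.
    """
    y_one_hot = np.asarray(y_one_hot)
    # The index of the largest entry in each row is the class id.
    return np.argmax(y_one_hot, axis=1)


# Hypothetical stand-in for y_train: 5 samples, 10 classes, one-hot encoded.
class_ids = np.array([3, 0, 9, 3, 7])
y_train_one_hot = np.eye(10, dtype="float32")[class_ids]

y_train_sparse = one_hot_to_sparse(y_train_one_hot)

print(y_train_one_hot.shape)  # shape of the one-hot targets
print(y_train_sparse.shape)   # shape of the integer targets
```

If the labels really are one-hot, an alternative that avoids the conversion is to keep them as they are and compile the model with loss='categorical_crossentropy'. Separately, Embedding looks up integer indices, so if the 768 values are real-valued features and not token ids, that layer may not be a good fit for this input.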