我已经使用keras训练了许多字符级语言模型。当我第一次拟合模型时,可以使用它,但是当我保存模型然后尝试使用它时,会出现值错误:
ValueError: Error when checking input: expected lstm_1_input to have shape (3, 54) but got array with shape (3, 52)
为什么仅在加载模型时出现此错误,如何使用已保存的模型?
我试图在线找到该问题的解决方案,但大多数google结果似乎都与尺寸变化有关。在我看来,所有发生的变化是,保存的模型似乎期望一个热矢量,该热矢量比应有的长两个。
为清楚起见,共有52个字符,因此(3,52)应该正确。
这是我用来构建模型的代码:
model = Sequential()
model.add(LSTM(40, return_sequences=True, input_shape=(x_train.shape[1], x_train.shape[2])))
model.add(LSTM(40, input_shape=(x_train.shape[1], x_train.shape[2])))
model.add(Dense(vocab_size, activation='softmax'))
print(model.summary())
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1000, verbose=2)
这是我的保存方式:
model.save('n3_Tokeniser.h5')
这是我的加载方式:
model = load_model('n3_Tokeniser.h5')
这是我尝试称呼它的方式:
def generate_seq(model, mapping, seq_length, seed_text, n_chars):
in_text = seed_text
for _ in range(n_chars):
encoded = [mapping[char] for char in in_text]
encoded = pad_sequences([encoded], maxlen=seq_length, truncating='pre')
encoded = to_categorical(encoded, num_classes=len(mapping))
yhat = model.predict_classes(encoded, verbose=0)
out_char = ''
for char, index in mapping.items():
if index == yhat:
out_char = char
break
in_text += char
return in_text
这是完整的追溯:
Traceback (most recent call last):
File "/Users/ade/PycharmProjects/Wbp/LSTM.py", line 263, in <module>
print(generate_seq(model, chardict, pre_characters, '$$$$$$$$$$', 20))
File "/Users/ade/PycharmProjects/Wbp/LSTM.py", line 250, in generate_seq
yhat = model.predict_classes(encoded, verbose=0)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/sequential.py", line 267, in predict_classes
proba = self.predict(x, batch_size=batch_size, verbose=verbose)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training.py", line 1149, in predict
x, _, _ = self._standardize_user_data(x)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training.py", line 751, in _standardize_user_data
exception_prefix='input')
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_utils.py", line 138, in standardize_input_data
str(data_shape))
ValueError: Error when checking input: expected lstm_1_input to have shape (3, 54) but got array with shape (3, 52)