文本预测LSTM神经网络的问题

时间:2019-11-08 10:22:48

标签: python tensorflow machine-learning keras neural-network

我正在尝试使用递归神经网络(LSTM)对来自书籍的数据集进行文本预测。无论我尝试更改图层大小或其他参数有多少,它总是会过拟合。

我一直在尝试更改层数,LSTM层中的单位数,正则化,规范化,batch_size,混洗训练数据/验证数据,将数据集更改为更大。现在,我尝试使用〜140kb txt图书。我还尝试了200kb,1mb,5mb。

创建训练/验证数据:

sequence_length = 30

x_data = []
y_data = []

for i in range(0, len(text) - sequence_length, 1):
    x_sequence = text[i:i + sequence_length]
    y_label = text[i + sequence_length]

    x_data.append([char2idx[char] for char in x_sequence])
    y_data.append(char2idx[y_label])

X = np.reshape(x_data, (data_length, sequence_length, 1))
X = X/float(vocab_length)
y = np_utils.to_categorical(y_data)

# Split into training and testing set, shuffle data
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2, shuffle=False)

# Shuffle testing set
X_test, y_test = shuffle(X_test, y_test, random_state=0)

创建模型:

model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(LSTM(256, return_sequences=True, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(LSTM(256, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))

enter image description here 编译模型:

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

我具有以下特征: enter image description here

我不知道如何解决这种过度拟合问题,因为我正在搜索互联网,尝试了许多尝试,但似乎都没有用。

如何获得更好的结果?这些预测目前似乎并不理想。

1 个答案:

答案 0 :(得分:1)

以下是我接下来要尝试的一些方法。 (我也是业余爱好者。如果我做错了,请纠正我)

  1. 尝试从文本中提取vector representation。尝试使用word2vec,GloVe,FastText,ELMo。提取矢量表示形式,然后将其输入网络。您也可以创建一个embedding layer来帮助您。该blog具有更多信息。
  2. 256个循环单位可能过多。我认为,永远不要从庞大的网络开始。从小开始。看看您是否身体不适。如果是,则放大。
  3. 关闭优化器。我发现亚当倾向于过度适应。我使用rmsprop和Adadelta取得了更好的成功。
  4. 也许,attention is all you need?变压器最近为NLP做出了巨大贡献。也许您可以在网络中尝试implementing simple soft attention mechanism。如果您还不熟悉,这里是nice video series。上面有一个interactive research paper
  5. 在NLP应用程序中,
  6. CNN也pretty dope。尽管从直觉上来说,它们对于文本数据没有任何意义(对大多数人而言)。也许您可以尝试利用它们,堆叠它们,等等。这是guide,介绍如何将其用于句子分类。我知道,您的域名不同。但我认为直觉会延续下去。 :)