Keras SimpleRNN

时间:2019-05-02 20:46:01

标签: python machine-learning keras tf.keras

我正在尝试对Keras进行分类。我有1043个字,表示为一键编码矩阵(20个字母长,每个字母26种可能性)。每个都属于19个不同类别之一。

if (has_post_thumbnail())
{
    echo '<img src="'.get_the_post_thumbnail_url(null, 'full').'?size=10" alt="" />';
}
X.shape >>>>>> (1043, 20, 26)

这是我尝试建立模型的尝试。

Y.shape >>>>>> (1043, 19)

这崩溃了:model = Sequential() model.add(SimpleRNN(50, input_shape=(20, 26), return_sequences=True)) model.add(Dense(40, activation='relu')) model.add(Dense(num_categories, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam') model.fit(X, Y, epochs=20, batch_size=5, verbose=1)

我感觉ValueError: Error when checking target: expected dense_91 to have 3 dimensions, but got array with shape (1043, 19)字段中缺少明显的内容,或者还有其他配置技巧吗?我也找不到在线此类问题的任何清晰示例。


更新:我怀疑我需要将时间片分解成一个最终答案,但是我不确定该怎么做。 input_shape似乎在正确的轨道上,但我无法正常工作。

1 个答案:

答案 0 :(得分:2)

Flatten()输出之后立即添加密集层之前,先添加RNN层。因为您有return_sequences=True,所以在将3维(batch_size, timesteps, 50)张量发送到密集层时,keras会从摘要序列的每个时间步中释放所有隐藏状态,从而导致错误。

>>> model = Sequential()
>>> model.add(SimpleRNN(50, input_shape=(20, 26), return_sequences=True))
>>> from keras.layers import Flatten
>>> model.add(Flatten())
>>> model.add(Dense(40, activation='relu'))
>>> model.add(Dense(num_categories, activation='softmax'))
>>> model.compile(loss='categorical_crossentropy', optimizer='adam')
>>> model.fit(X, Y, epochs=20, batch_size=5, verbose=1)
1043/1043 [==============================] - 3s 3ms/step - loss: -0.0735

但是,我不建议您将return_sequences设置为True,而应该不包含该参数,而是直接将其放入密集层。您没有做seq2seq问题-return_sequences最常使用的问题。相反,

>>> model = Sequential()
>>> model.add(SimpleRNN(50, input_shape=(20, 26)))
>>> model.add(Dense(40, activation='relu'))
>>> model.add(Dense(num_categories, activation='softmax'))
>>> model.compile(loss='categorical_crossentropy', optimizer='adam')
>>> model.fit(X, Y, epochs=20, batch_size=5, verbose=1)
Epoch 1/20
 910/1043 [=========================>....] - ETA: 0s - loss: -0.3609

最后的建议是使用GRU之类的不同RNN模型,并同时使用Embedding层和诸如GLoVE之类的经过预训练的单词向量。不使用预训练的单词嵌入会导致对小数据集的验证性能不佳。您可以看到this SO answer来帮助使用这些嵌入。您可能还想查看keras' functional API-更好。