如何使用从LSTM递归神经网络中提取的权重

时间:2018-03-23 11:02:33

标签: python tensorflow machine-learning neural-network keras

我已经用Python中的Keras训练了LSTM递归神经网络用于序列(时间序列)分类。

要素按形状进行整理(batch_size,timesteps,data_dim)。我的训练样例共有1000个。最终目标是在5个班级中进行分类。这是我的代码片段。

#defining some model features
data_dim = 15
timesteps = 20
num_classes = len(one_hot_train_labels[1,:])
batch_size = len(ytrain) 

#reshaping array for LSTM training
xtrain=numpy.reshape(xtrain, (len(ytrain), timesteps, data_dim))
xtest=numpy.reshape(xtest, (len(ytest), timesteps, data_dim))

rms = optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=None, decay=0.0) #It is recommended to leave the parameters
#of this optimizer at their default values (except the learning rate, which can be freely tuned).

# create the model
model = Sequential()
model.add(LSTM(101, dropout=0.5, recurrent_dropout=0.5, input_shape=(timesteps, data_dim), activation='tanh'))
model.add(Dense(5, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer=rms, metrics=['accuracy'])
print(model.summary())
history = model.fit(xtrain, one_hot_train_labels, epochs=200, batch_size=10)
# Final evaluation of the model
scores = model.evaluate(xtrain, one_hot_train_labels, verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))
scores = model.evaluate(xtest, one_hot_test_labels, verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))

由于我想在别处使用和实现分类器,我使用以下方法提取了权重:

weights = [model.get_weights() for layer in model.layers]

过去使用传统的神经网络和逻辑回归,我期望每层有2个矩阵,一个用多项式权重,一个用偏差单位,然后使用激活函数(在这种情况下是tanh和softmax函数) )逐步找到落入5个班级之一的概率。

但我现在感到困惑,因为调用权重会返回5个具有以下大小的矩阵:

  • (15,400)
  • (100,400)
  • (400)
  • (100,5)
  • (5)

现在,我理解LSTM可以使用4个不同的块:

  1. 从矢量输入
  2. 上一个块的内存
  3. 来自当前块的内存
  4. 上一个块的输出
  5. 因此为什么我的矩阵的2n维度的大小为400。

    现在我的问题是:

    如何使用激活函数(如传统神经网络)以级联方式最终获得类概率?

    输入图层的偏置单位在哪里?

    感谢大家帮忙澄清并帮助理解如何将这个强大的工具用作LSTM网络。

    希望这不仅对我有帮助。

1 个答案:

答案 0 :(得分:0)

当你说获得类概率时,我猜你想要类预测(?)你可以在训练网络后使用model.predict()来获得类概率。您希望预测时最好是model.save_weights(filename)然后model.load_weights(filename)。输入图层没有偏差,您可以看到图层与model.summary()

有多少参数