I'm building an LSTM network from scratch, based on my own understanding of how LSTM cells work.
There are no layers, so I'm trying to implement the non-vectorized forms of the equations seen in tutorials. I'm also using peepholes from the cell state.
So far I understand that it looks like this: [image: LSTM network]
I've made these equations for each of the gates in the forward pass:
i_t = sigmoid( i_w * (x_t + c_t) + i_b )
f_t = sigmoid( f_w * (x_t + c_t) + f_b )
cell_gate = tanh( c_w * x_t + c_b )
c_t = (f_t * c_t) + (i_t * cell_gate)
o_t = sigmoid( o_w * (x_t + c_t) + o_b )
h_t = o_t * tanh(c_t)
Where _w means the weights for the respective gate and _b the respective bias. Also, that first activation on the far left is the one I've named "cell_gate".
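In case it helps to see them run, here is a minimal scalar sketch of the forward equations above (the class and field names are mine for illustration, not from any library; all weights start at zero):

```java
// Scalar sketch of the forward-pass equations above (illustrative names).
public class LstmCellSketch {
    // gate weights and biases, one scalar each (i = input, f = forget,
    // c = cell_gate, o = output); Java initializes them to 0.0
    double iW, iB, fW, fB, cW, cB, oW, oB;
    double cPrev; // cell state carried over from the previous time step

    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    /** One forward step for input x_t; returns h_t and updates the cell state. */
    double step(double x) {
        double i = sigmoid(iW * (x + cPrev) + iB);  // i_t (peephole on cell state)
        double f = sigmoid(fW * (x + cPrev) + fB);  // f_t
        double cellGate = Math.tanh(cW * x + cB);   // cell_gate
        double c = f * cPrev + i * cellGate;        // c_t
        double o = sigmoid(oW * (x + c) + oB);      // o_t (peeks at the new c_t)
        cPrev = c;
        return o * Math.tanh(c);                    // h_t
    }
}
```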
Backpropagation is where things get fuzzy for me, and I'm not sure how to derive these equations properly.
I know in general that the error calculation is: error = f'(x_t) * (received_error), where f'(x_t) is the first derivative of the activation function and received_error is either (target - output) for an output neuron or Σ(o_e * w_io) for a hidden neuron.
Where o_e is the error of one of the cells that the current cell outputs to, and w_io is the weight connecting them.
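That general rule can be written out as a scalar sketch (names are mine; tanh is used here purely as an example activation):

```java
// The general error rule above, in code form (illustrative sketch).
public class ErrorRule {
    // first derivative of tanh, used as the example activation f
    static double tanhPrime(double x) {
        double t = Math.tanh(x);
        return 1.0 - t * t;
    }

    // output neuron: received_error = (target - output)
    static double outputError(double x, double target, double output) {
        return tanhPrime(x) * (target - output);
    }

    // hidden neuron: received_error = Σ(o_e * w_io) over downstream cells
    static double hiddenError(double x, double[] downstreamErrors, double[] weights) {
        double sum = 0.0;
        for (int i = 0; i < downstreamErrors.length; i++) {
            sum += downstreamErrors[i] * weights[i];
        }
        return tanhPrime(x) * sum;
    }
}
```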
I'm not sure whether the LSTM cell as a whole counts as a neuron, so I treated each gate as a neuron and tried to calculate an error signal for each gate. Then I used the error signal from the cell gate alone to pass back through the network...:
o_e = sigmoid'(o_w * (x_t + c_t) + o_b) * (received_error)
o_w += o_l * x_t * o_e
o_b += o_l * sigmoid(o_b) * o_e
...and the rest of the gates follow the same format...
Then the error for the whole LSTM cell is equal to o_e.
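Transcribing those output-gate equations literally into code gives the following scalar sketch (o_l is the learning rate; receivedError is whatever arrives from the layer above; names are mine):

```java
// Literal transcription of the proposed output-gate update equations.
public class OutputGateUpdate {
    double oW, oB; // output-gate weight and bias, initialized to 0.0

    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    /** Computes o_e, applies the weight/bias updates, and returns o_e. */
    double update(double x, double c, double oL, double receivedError) {
        double pre = oW * (x + c) + oB;                // o_w * (x_t + c_t) + o_b
        double act = sigmoid(pre);
        double oE = act * (1.0 - act) * receivedError; // sigmoid'(pre) * received_error
        oW += oL * x * oE;                             // o_w += o_l * x_t * o_e
        oB += oL * sigmoid(oB) * oE;                   // o_b += o_l * sigmoid(o_b) * o_e
        return oE;
    }
}
```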
Then for the LSTM cell one step above the current one, the error it receives equals:
tanh'(x_t) * Σ(o_e * w_io)
Is this all correct? Am I doing anything completely wrong?
Answer (score: 0)
I took a crack at this as well, and I believe your approach is correct:
https://github.com/evolvingstuff/LongShortTermMemory/blob/master/src/com/evolvingstuff/LSTM.java
Some nice work from Thomas Lahore:
//////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////
//BACKPROP
//////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////

//scale partials
for (int c = 0; c < cell_blocks; c++) {
    for (int i = 0; i < full_input_dimension; i++) {
        this.dSdwWeightsInputGate[c][i] *= ForgetGateAct[c];
        this.dSdwWeightsForgetGate[c][i] *= ForgetGateAct[c];
        this.dSdwWeightsNetInput[c][i] *= ForgetGateAct[c];

        dSdwWeightsInputGate[c][i] += full_input[i] * neuronInputGate.Derivative(InputGateSum[c]) * NetInputAct[c];
        dSdwWeightsForgetGate[c][i] += full_input[i] * neuronForgetGate.Derivative(ForgetGateSum[c]) * CEC1[c];
        dSdwWeightsNetInput[c][i] += full_input[i] * neuronNetInput.Derivative(NetInputSum[c]) * InputGateAct[c];
    }
}

if (target_output != null) {
    double[] deltaGlobalOutputPre = new double[output_dimension];
    for (int k = 0; k < output_dimension; k++) {
        deltaGlobalOutputPre[k] = target_output[k] - output[k];
    }

    //output to hidden
    double[] deltaNetOutput = new double[cell_blocks];
    for (int k = 0; k < output_dimension; k++) {
        //links
        for (int c = 0; c < cell_blocks; c++) {
            deltaNetOutput[c] += deltaGlobalOutputPre[k] * weightsGlobalOutput[k][c];
            weightsGlobalOutput[k][c] += deltaGlobalOutputPre[k] * NetOutputAct[c] * learningRate;
        }
        //bias
        weightsGlobalOutput[k][cell_blocks] += deltaGlobalOutputPre[k] * 1.0 * learningRate;
    }

    for (int c = 0; c < cell_blocks; c++) {
        //update output gates
        double deltaOutputGatePost = deltaNetOutput[c] * CECSquashAct[c];
        double deltaOutputGatePre = neuronOutputGate.Derivative(OutputGateSum[c]) * deltaOutputGatePost;
        for (int i = 0; i < full_input_dimension; i++) {
            weightsOutputGate[c][i] += full_input[i] * deltaOutputGatePre * learningRate;
        }
        peepOutputGate[c] += CEC3[c] * deltaOutputGatePre * learningRate;

        //before outgate
        double deltaCEC3 = deltaNetOutput[c] * OutputGateAct[c] * neuronCECSquash.Derivative(CEC3[c]);

        //update input gates
        double deltaInputGatePost = deltaCEC3 * NetInputAct[c];
        double deltaInputGatePre = neuronInputGate.Derivative(InputGateSum[c]) * deltaInputGatePost;
        for (int i = 0; i < full_input_dimension; i++) {
            weightsInputGate[c][i] += dSdwWeightsInputGate[c][i] * deltaCEC3 * learningRate;
        }
        peepInputGate[c] += CEC2[c] * deltaInputGatePre * learningRate;

        //before ingate
        double deltaCEC2 = deltaCEC3;

        //update forget gates
        double deltaForgetGatePost = deltaCEC2 * CEC1[c];
        double deltaForgetGatePre = neuronForgetGate.Derivative(ForgetGateSum[c]) * deltaForgetGatePost;
        for (int i = 0; i < full_input_dimension; i++) {
            weightsForgetGate[c][i] += dSdwWeightsForgetGate[c][i] * deltaCEC2 * learningRate;
        }
        peepForgetGate[c] += CEC1[c] * deltaForgetGatePre * learningRate;

        //update cell inputs
        for (int i = 0; i < full_input_dimension; i++) {
            weightsNetInput[c][i] += dSdwWeightsNetInput[c][i] * deltaCEC3 * learningRate;
        }
        //no peeps for cell inputs
    }
}

//////////////////////////////////////////////////////////////

//roll-over context to next time step
for (int j = 0; j < cell_blocks; j++) {
    context[j] = NetOutputAct[j];
    CEC[j] = CEC3[j];
}
Also, perhaps even more interesting, a lecture and lecture notes from Andrej Karpathy:
https://youtu.be/cO0a0QYmFm8?t=45m36s
http://cs231n.stanford.edu/slides/2016/winter1516_lecture10.pdf