我正在使用https://github.com/wojzaremba/lstm
中实施的LSTM语言模型它使用以下lstm函数
local function lstm(x, prev_c, prev_h)
-- Calculate all four gates in one go
local i2h = nn.Linear(params.rnn_size, 4*params.rnn_size)(x)
local h2h = nn.Linear(params.rnn_size, 4*params.rnn_size)(prev_h)
local gates = nn.CAddTable()({i2h, h2h})
-- Reshape to (batch_size, n_gates, hid_size)
-- Then slize the n_gates dimension, i.e dimension 2
local reshaped_gates = nn.Reshape(4,params.rnn_size)(gates)
local sliced_gates = nn.SplitTable(2)(reshaped_gates)
-- Use select gate to fetch each gate and apply nonlinearity
local in_gate = nn.Sigmoid()(nn.SelectTable(1)(sliced_gates))
local in_transform = nn.Tanh()(nn.SelectTable(2)(sliced_gates))
local forget_gate = nn.Sigmoid()(nn.SelectTable(3)(sliced_gates))
local out_gate = nn.Sigmoid()(nn.SelectTable(4)(sliced_gates))
local next_c = nn.CAddTable()({
nn.CMulTable()({forget_gate, prev_c}),
nn.CMulTable()({in_gate, in_transform})
})
local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
return next_c, next_h
end
在以下网络中使用了哪个(我删除了softmax和条件图层,并将它们分别添加到代码中的其他位置)
local function create_network()
local x = nn.Identity()()
local prev_s = nn.Identity()()
local i = {[0] = x}
local next_s = {}
local split = {prev_s:split(2 * params.layers)}
for layer_idx = 1, params.layers do
local prev_c = split[2 * layer_idx - 1]
local prev_h = split[2 * layer_idx]
local dropped = nn.Dropout(params.dropout)(i[layer_idx - 1])
local next_c, next_h = lstm(dropped, prev_c, prev_h)
table.insert(next_s, next_c)
table.insert(next_s, next_h)
i[layer_idx] = next_h
end
local res = nn.Identity()(i[params.layers])
local module = nn.gModule({x, prev_s},
{res, nn.Identity()(next_s)})
return module
end
上述网络返回网络的输出和将在下一次迭代中使用的lstm层的状态。对于2层lstm网络,状态按以下顺序保存在表中{cell_1,output_1,cell_2,output_2}。网络输出和输出_2是相同的。
我有两个问题: (1)当我在这个网络上应用前向和后向传播时,状态的梯度是如何排列的?它们是否与上表具有相同的顺序,或者它们是否会像这样反转:{grad_cell_2,grad_output_2,grad_cell_1,grad_output_1}
我最初认为它们与输出表的顺序相同,但我有理由怀疑顺序是相反的(基于我手动设置每次迭代的渐变的一些测试)。我不确定,但我不知道如何调试此代码以确切知道发生了什么。
(2)在后退步骤中,如果我只知道输出的梯度(与状态表中的最后一个条目相同),我应该传递输出(res)或状态表的渐变( next_s)还是两者兼而有之?我认为仅将渐变传递给输出或仅传递状态表的最后一个条目会给我完全相同的结果,因为输出只是表中的最后一个条目。但是,当我尝试两种方式时,我会得到不同的结果。