在训练神经网络之后,让我们说一个多层感知器,在预测时我想将第一层与其他层分开。 为此,我找到的唯一方法是获得正确大小的文件,如下所示:
我循环遍历所有图层,然后将它们添加到两个容器中的一个(first layer
或all the others
),我将使用torch.save
函数保存separetelly。有趣的是,我需要在将每个图层的参数添加到两个容器中的任何一个之前检索每个图层的参数,否则在保存时,两个文件(first layer
和all the other layers
)具有相同的文件大小。 / p>
代码片段比我以前的解释更有帮助:
local function split_model(network)
-- for some reason all the models when saved have the same size
-- if not splitted calling 'getParameters()' first.
first_layer = nn.Sequential()
all_the_rest = nn.Sequential()
for i = 1, network:size() do
local l = network:get(i)
local l_params, _ = l:getParameters()
if i == 1 then
first_layer:add(l)
else
all_the_rest:add(l)
end
end
return first_layer, all_the_rest
end
local first_layer, all_the_rest = split_model(network)
torch.save("checkpoints/mlp.t7", .network)
torch.save("checkpoints/first_layer.t7", first_layer)
torch.save("checkpoints/all_the_rest.t7", all_the_rest)
答案 0 :(得分:0)
同样的问题发布在Google groups,这是Alban Desmaison的回答:
您好,
此行为的原因是
getParameters
的工作方式 https://github.com/torch/nn/blob/master/doc/module.md#flatparameters-flatgradparameters-getparameters 为了能够返回包含所有参数的平坦张量,它 实际上创建一个包含所有权重的存储然后 每个模块的重量都是此存储的一部分。当你保存 网络中任何元素的权重,都必须保存重量 张量和这样做,保存底层存储。因此,如果你打电话getParameters
在整个网络上,如果您保存任何模块,则为您 将保存所有网络权重。在这里,当你打电话getParameters
在单个模块上,它实际上重新创建了这个 单个存储,但对于这个单个模块,因此,当您保存它时 仅包含您想要的权重。但请注意扁平化 您在getParameters
上返回的参数 完整的网络不再有效!!!这里有两个解决方案: - 如果您不想使用来自整个网络
getParameters
的参数,您只需拨打电话即可 在保存之前,在网络的每个子集上getParameters
。 这将打破更改底层存储只包含此内容 网络的子集,您将只保存您需要的东西(共享 存储只存储一次)。 - 如果您希望能够继续使用原始getParameters
中的参数,您可以执行与上述相同但使用克隆版本 他们要做getParameters
并保存。因为代码片段总是更好:
require 'nn'
local subset1 = nn.Linear(2,2)
local subset2 = nn.Linear(2,2)
local network = nn.Sequential():add(subset1):add(subset2)
print("Before getParameters:", subset1.weight:storage():size()) -- 4 elements
network_params,_ = network:getParameters()
print("After getParameters:", subset1.weight:storage():size()) -- 12 elements
subset1.weight:random() -- Change weights to see if linking is still working
print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- true
-- Keeping network_params valid
local clone_subset1 = subset1:clone()
print("Cloned subset1 before getParameters:", clone_subset1.weight:storage():size()) -- 12 elements
clone_subset1:getParameters()
print("Cloned subset1 after getParameters:", clone_subset1.weight:storage():size()) -- 6 elements (4 weights + 2 bias)
subset1.weight:random() -- Change weights to see if linking is still working
print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- true
-- Not keeping network_params valid (should be faster)
local clone_subset1 = subset1:clone()
print("subset1 before getParameters:", subset1.weight:storage():size()) -- 12 elements
subset1:getParameters()
print("subset1 after getParameters:", subset1.weight:storage():size()) -- 6 elements (4 weights + 2 bias)
subset1.weight:random() -- Change weights to see if linking is still working
print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- false