PyTorch:使用numpy数组为GRU / LSTM手动设置权重参数

时间:2018-10-23 09:17:56

标签: python lstm pytorch rnn

我正在尝试在pytorch中使用手动定义的参数填充GRU / LSTM。

我有numpy数组,用于具有其文档(https://pytorch.org/docs/stable/nn.html#torch.nn.GRU)中定义的形状的参数。

这似乎可行,但是我不确定返回的值是否正确。

这是用numpy参数填充GRU / LSTM的正确方法吗?

gru = nn.GRU(input_size, hidden_size, num_layers,
              bias=True, batch_first=False, dropout=dropout, bidirectional=bidirectional)

def set_nn_wih(layer, parameter_name, w, l0=True):
    param = getattr(layer, parameter_name)
    if l0:
        for i in range(3*hidden_size):
            param.data[i] = w[i*input_size:(i+1)*input_size]
    else:
        for i in range(3*hidden_size):
            param.data[i] = w[i*num_directions*hidden_size:(i+1)*num_directions*hidden_size]

def set_nn_whh(layer, parameter_name, w):
    param = getattr(layer, parameter_name)
    for i in range(3*hidden_size):
        param.data[i] = w[i*hidden_size:(i+1)*hidden_size]

l0=True

for i in range(num_directions):
    for j in range(num_layers):
        if j == 0:
            wih = w0[i, :, :3*input_size]
            whh = w0[i, :, 3*input_size:]  # check
            l0=True
        else:
            wih = w[j-1, i, :, :num_directions*3*hidden_size]
            whh = w[j-1, i, :, num_directions*3*hidden_size:]
            l0=False

        if i == 0:
            set_nn_wih(
                gru, "weight_ih_l{}".format(j), torch.from_numpy(wih.flatten()),l0)
            set_nn_whh(
                gru, "weight_hh_l{}".format(j), torch.from_numpy(whh.flatten()))
        else:
            set_nn_wih(
                gru, "weight_ih_l{}_reverse".format(j), torch.from_numpy(wih.flatten()),l0)
            set_nn_whh(
                gru, "weight_hh_l{}_reverse".format(j), torch.from_numpy(whh.flatten()))

y, hn = gru(x_t, h_t)

numpy数组定义如下:

rng = np.random.RandomState(313)
w0 = rng.randn(num_directions, hidden_size, 3*(input_size +
               hidden_size)).astype(np.float32)
w = rng.randn(max(1, num_layers-1), num_directions, hidden_size,
              3*(num_directions*hidden_size + hidden_size)).astype(np.float32)

2 个答案:

答案 0 :(得分:4)

这是一个很好的问题,您已经给出了不错的答案。但是,它重新发明了轮子-一个非常优雅的Pytorch内部例程,使您无需花费太多精力即可完成此操作-并且适用于任何网络。

此处的核心概念是PyTorch的{​​{1}}。状态字典有效地包含state_dict,该parametersnn.Modules及其子模块之间的关系给出的树结构组织起来。

这是一个状态字典如何查找GRU的示例(我选择了input_size = hidden_size = 2,以便可以打印整个状态字典)

rnn = torch.nn.GRU(2, 2, 1)
rnn.state_dict()
# Out[10]: 
#     OrderedDict([('weight_ih_l0', tensor([[-0.0023, -0.0460],
#                         [ 0.3373,  0.0070],
#                         [ 0.0745, -0.5345],
#                         [ 0.5347, -0.2373],
#                         [-0.2217, -0.2824],
#                         [-0.2983,  0.4771]])),
#                 ('weight_hh_l0', tensor([[-0.2837, -0.0571],
#                         [-0.1820,  0.6963],
#                         [ 0.4978, -0.6342],
#                         [ 0.0366,  0.2156],
#                         [ 0.5009,  0.4382],
#                         [-0.7012, -0.5157]])),
#                 ('bias_ih_l0',
#                 tensor([-0.2158, -0.6643, -0.3505, -0.0959, -0.5332, -0.6209])),
#                 ('bias_hh_l0',
#                 tensor([-0.1845,  0.4075, -0.1721, -0.4893, -0.2427,  0.3973]))])

因此state_dict是网络的所有参数。如果我们有“嵌套” nn.Modules,我们将得到由参数名称表示的树:

class MLP(torch.nn.Module):      
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.lin_a = torch.nn.Linear(2, 2)
        self.lin_b = torch.nn.Linear(2, 2)


mlp = MLP()
mlp.state_dict()
#    Out[23]: 
#        OrderedDict([('lin_a.weight', tensor([[-0.2914,  0.0791],
#                            [-0.1167,  0.6591]])),
#                    ('lin_a.bias', tensor([-0.2745, -0.1614])),
#                    ('lin_b.weight', tensor([[-0.4634, -0.2649],
#                            [ 0.4552,  0.3812]])),
#                    ('lin_b.bias', tensor([ 0.0273, -0.1283]))])


class NestedMLP(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.mlp_a = MLP()
        self.mlp_b = MLP()


n_mlp = NestedMLP()
n_mlp.state_dict()
#   Out[26]: 
#        OrderedDict([('mlp_a.lin_a.weight', tensor([[ 0.2543,  0.3412],
#                            [-0.1984, -0.3235]])),
#                    ('mlp_a.lin_a.bias', tensor([ 0.2480, -0.0631])),
#                    ('mlp_a.lin_b.weight', tensor([[-0.4575, -0.6072],
#                            [-0.0100,  0.5887]])),
#                    ('mlp_a.lin_b.bias', tensor([-0.3116,  0.5603])),
#                    ('mlp_b.lin_a.weight', tensor([[ 0.3722,  0.6940],
#                            [-0.5120,  0.5414]])),
#                    ('mlp_b.lin_a.bias', tensor([0.3604, 0.0316])),
#                    ('mlp_b.lin_b.weight', tensor([[-0.5571,  0.0830],
#                            [ 0.5230, -0.1020]])),
#                    ('mlp_b.lin_b.bias', tensor([ 0.2156, -0.2930]))])

那么-如果您不想提取状态dict而是对其进行更改,从而更改网络参数怎么办?使用nn.Module.load_state_dict(state_dict, strict=True)link to the docs) 此方法允许您将具有任意值的整个state_dict加载到相同类型的实例化模型中,只要键(即参数名称)正确且值(即参数)为{ {1}}的形状正确。 如果将torch.tensors kwarg设置为strict(默认值),则加载的dict必须与原始状态dict完全匹配,除了参数的值。也就是说,每个参数必须有一个新值。

对于上面的GRU示例,我们需要为True中的每一个使用正确大小的张量(以及正确的设备,btw)。由于有时我们只想加载一些值(就像我想的那样),我们可以将'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0' kwarg设置为strict-然后我们只能加载部分国家命令,例如一个仅包含False的参数值。

作为实用建议,我只需要创建要向其中加载值的模型,然后打印状态字典(或至少是键列表和各自的张量大小)

'weight_ih_l0'

这告诉您要更改参数的确切名称。然后,您只需使用相应的参数名称和张量创建状态dict,然后加载它:

print([k, v.shape for k, v in model.state_dict().items()])

答案 1 :(得分:0)

如果您想设置一定的权重/偏见(或一些),我喜欢这样做:

model.state_dict()["your_weight_names_here"][:] = torch.Tensor(your_numpy_array)