如何在PyTorch中更新神经网络的参数?

时间:2018-03-23 09:52:37

标签: pytorch

假设我想将{em> PyTorch (一个继承自torch.nn.Module的类的实例)的神经网络的所有参数乘以0.9。我该怎么做?

2 个答案:

答案 0 :(得分:2)

net成为神经网络类的一个实例。然后你可以做

state_dict = net.state_dict()

for name, param in state_dict.items():
    # Transform the parameter as required.
    transformed_param = param * 0.9

    # Update the parameter.
    state_dict[name].copy_(transformed_param)

将所有参数乘以0.9

如果你只想更新权重而不是所有参数,你可以做

state_dict = net.state_dict()

for name, param in state_dict.items():
    # Don't update if this is not a weight.
    if not "weight" in name:
        continue

    # Transform the parameter as required.
    transformed_param = param * 0.9

    # Update the parameter.
    state_dict[name].copy_(transformed_param)

答案 1 :(得分:0)

实现此目的的另一种方法是使用 <fieldset> <legend>Input Area</legend> <br><br> <label for="week">Week Number:</label> <input type="number" id="week" maxlength="2" size="2" value="0"> <label for="fname">First name:</label> <input type="text" id="firstname" name="fname"> <label for="lname">Last name:</label> <input type="text" id="lastname" name="lname"> <label for="studentnumber">Student Number:</label> <input type="number" id="studentnumber" name="number"> <button id = "button">Generate Lotto Tickets</button> </fieldset> <fieldset> <legend id="display">Display Area</legend> <label for="title">Module Title:</label> <output id="wTitle"><i>module title</i></output><br><br> <label for="fname">Student info:</label> <output id="studentinfo"><i>Student info</i></output> <label for="date">Current Date:</label> <output id="cdate"><i>Current Date</i></output> </fieldset>

初始化模块:

tensor.parameters()

更改参数:

>>> a = torch.nn.Linear(2, 2)
>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0.1770, -0.2151],
                      [-0.6543,  0.6637]])),
             ('bias', tensor([-0.0524,  0.6807]))])

看效果:

for p in a.parameters():
    p.data *= 0