用python重置神经网络中的连接

时间:2019-10-05 08:02:50

标签: neural-network pytorch

我有一个子类nn.module的代码。

我不知道确切的reset_()函数做什么,并且在nn.module源代码中找不到任何reset_()函数。

谁知道在没有任何操作且父类的名称reset_()中也没有任何功能的情况下,该如何在神经网络中使用它来重置连接???


    class Connection(torch.nn.module):
      super().__init__()

      def reset_(self) -> None:

      #Contains resetting logic for the connection.#

      super().reset_()

1 个答案:

答案 0 :(得分:0)

尽管我不确定PyTorch模块中的reset()函数是什么意思,但是,通常在许多NN层中,都有一个reset_parameters()函数用于重置该参数层。我帮您举个例子。

import torch
import torch.nn as nn


class Connection(nn.Module):

    def __init__(self):
        super().__init__()
        # a weight matrix of shape [10 x 100] as parameters
        self.weight = nn.Parameter(torch.Tensor(10, 100))

    def reset_parameters(self) -> None:
        # reset parameters using random values from a uniform distribution
        nn.init.uniform_(self.weight, -0.01, 0.01)


c = Connection()
c.reset_parameters() # reset the weight parameters

这仅仅是一个示例,您可以修改reset_parameters函数来满足您的需求。