我有一个子类nn.module
的代码。
我不知道确切的reset_()
函数做什么,并且在nn.module源代码中找不到任何reset_()
函数。
谁知道在没有任何操作且父类的名称reset_()
中也没有任何功能的情况下,该如何在神经网络中使用它来重置连接???
class Connection(torch.nn.module):
super().__init__()
def reset_(self) -> None:
#Contains resetting logic for the connection.#
super().reset_()
答案 0 :(得分:0)
尽管我不确定PyTorch模块中的reset()
函数是什么意思,但是,通常在许多NN层中,都有一个reset_parameters()
函数用于重置该参数层。我帮您举个例子。
import torch
import torch.nn as nn
class Connection(nn.Module):
def __init__(self):
super().__init__()
# a weight matrix of shape [10 x 100] as parameters
self.weight = nn.Parameter(torch.Tensor(10, 100))
def reset_parameters(self) -> None:
# reset parameters using random values from a uniform distribution
nn.init.uniform_(self.weight, -0.01, 0.01)
c = Connection()
c.reset_parameters() # reset the weight parameters
这仅仅是一个示例,您可以修改reset_parameters
函数来满足您的需求。