Pytorch线性模块类定义中的常量

时间:2019-08-16 10:11:25

标签: pytorch

https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html中定义的火炬__constants__中的class Linear(Module):是什么?

它的功能是什么,为什么要使用它?

我一直在搜索,但是没有找到任何文档。请注意,这并不意味着割炬脚本中的__constants__

1 个答案:

答案 0 :(得分:2)

实际上,您所谈论的__constants__是与TorchScript相关的那个。您可以使用GitHub上的git blame (添加时和添加者)进行确认。例如,对于torch/nn/modules/linear.py,请检查其git blame

  

TorchScript还提供了一种使用Python中定义的常量的方法。这些可用于将超参数硬编码到函数中,或用于定义通用常量。

     

-可以通过将ScriptModule的属性列为该类的 constants 属性的成员来将其标记为常量:

class Foo(torch.jit.ScriptModule):
    __constants__ = ['a']

    def __init__(self):
        super(Foo, self).__init__(False)
        self.a = 1 + 4

   @torch.jit.script_method
   def forward(self, input):
       return self.a + input