PyTorch:@weak_script_method装饰器做什么?

时间:2019-02-15 22:47:15

标签: python decorator pytorch

torch.nn.Linear类(以及其他类)中,forward方法包括一个@weak_script_method装饰器,如下所示:

@weak_script_method
def forward(self, input):
    return F.linear(input, self.weight, self.bias)

这个装饰器做什么? 如果我要在forward模块自己的子类中覆盖Linear方法,应该包括它吗?

1 个答案:

答案 0 :(得分:1)

您可以找到确切的decorator location来了解想法。

def weak_script_method(fn):
    weak_script_methods[fn] = {
        "rcb": createResolutionCallback(frames_up=2),
        "original_method": fn
    }
return fn

但是,您不必担心该装饰器。此装饰器位于JIT的内部。

技术上用@weak_script_method装饰的方法将被添加到前面创建的weak_script_methods字典中,如下所示:

weak_script_methods = weakref.WeakKeyDictionary() 

该指令跟踪避免循环依赖问题的方法;方法在创建PyTorch图时调用其他方法。


除非您大致了解TorchScript的概念,否则这确实没有多大意义。

TorchScript的想法是在PyTorch中训练模型并将模型导出到另一个支持静态类型的非Python生产环境(阅读:C ++ / C / Cuda)。

PyTorch团队在有限的Python基础上制作了TorchScript以支持静态类型。 默认情况下,Python是动态类型的语言,但是很少有技巧(阅读:检查),它可以成为静态类型的语言。

因此,TorchScript函数是Python的静态类型子集,其中包含PyTorch的所有内置Tensor操作。这种差异使得TorchScript模块代码无需使用Python解释器即可运行。

您可以使用跟踪(torch.jit.trace()方法)将现有的PyTorch方法转换为TorchScript,也可以使用@torch.jit.script装饰器手动创建TorchScripts。

如果使用跟踪,则最后将得到一个类模块。这是示例:

import inspect

import torch
def foo(x, y):
    return x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

print(type(traced_foo)) #<class 'torch.jit.TopLevelTracedModule'>
print(traced_foo) #foo()
print(traced_foo.forward) #<bound method TopLevelTracedModule.forward of foo()>

lines = inspect.getsource(traced_foo.forward)
print(lines)

输出:

<class 'torch.jit.TopLevelTracedModule'>
foo()
<bound method TopLevelTracedModule.forward of foo()>
    def forward(self, *args, **kwargs):
        return self._get_method('forward')(*args, **kwargs)

您可以使用检查模块进一步调查。这只是展示如何使用跟踪转换一个函数。