在torch.nn.Linear
类(以及其他类)中,forward
方法包括一个@weak_script_method
装饰器,如下所示:
@weak_script_method
def forward(self, input):
return F.linear(input, self.weight, self.bias)
这个装饰器做什么? 如果我要在forward
模块自己的子类中覆盖Linear
方法,应该包括它吗?
答案 0 :(得分:1)
您可以找到确切的decorator location来了解想法。
def weak_script_method(fn):
weak_script_methods[fn] = {
"rcb": createResolutionCallback(frames_up=2),
"original_method": fn
}
return fn
但是,您不必担心该装饰器。此装饰器位于JIT的内部。
技术上用@weak_script_method
装饰的方法将被添加到前面创建的weak_script_methods
字典中,如下所示:
weak_script_methods = weakref.WeakKeyDictionary()
该指令跟踪避免循环依赖问题的方法;方法在创建PyTorch图时调用其他方法。
除非您大致了解TorchScript的概念,否则这确实没有多大意义。
TorchScript的想法是在PyTorch中训练模型并将模型导出到另一个支持静态类型的非Python生产环境(阅读:C ++ / C / Cuda)。
PyTorch团队在有限的Python基础上制作了TorchScript以支持静态类型。 默认情况下,Python是动态类型的语言,但是很少有技巧(阅读:检查),它可以成为静态类型的语言。
因此,TorchScript函数是Python的静态类型子集,其中包含PyTorch的所有内置Tensor操作。这种差异使得TorchScript模块代码无需使用Python解释器即可运行。
您可以使用跟踪(torch.jit.trace()
方法)将现有的PyTorch方法转换为TorchScript,也可以使用@torch.jit.script
装饰器手动创建TorchScripts。
如果使用跟踪,则最后将得到一个类模块。这是示例:
import inspect
import torch
def foo(x, y):
return x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
print(type(traced_foo)) #<class 'torch.jit.TopLevelTracedModule'>
print(traced_foo) #foo()
print(traced_foo.forward) #<bound method TopLevelTracedModule.forward of foo()>
lines = inspect.getsource(traced_foo.forward)
print(lines)
输出:
<class 'torch.jit.TopLevelTracedModule'>
foo()
<bound method TopLevelTracedModule.forward of foo()>
def forward(self, *args, **kwargs):
return self._get_method('forward')(*args, **kwargs)
您可以使用检查模块进一步调查。这只是展示如何使用跟踪转换一个函数。