我正在尝试在源代码中跟踪torch.nn.NLLLoss
的实现。我在文件torch._C.nll_loss
中的函数nll_loss
中调用了torch.nn.functional
。但我找不到创建_C的地方。
任何人都有这方面的任何信息?
答案 0 :(得分:0)
在PyTorch博客上查看A Tour of PyTorch Internals。相关摘录:
该帖子的PyTorch定义了一个新的火炬包。在这篇文章中,我们将考虑._C模块。这个模块被称为“扩展模块” - 用C语言编写的Python模块。这些模块允许我们定义新的内置对象类型(例如Tensor)并调用C / C ++函数。
._C模块在torch / csrc / Module.cpp中定义。 init_C()/ PyInit__C()函数创建模块并根据需要添加方法定义。该模块被传递给许多不同的__init()函数,这些函数将更多对象添加到模块,注册新类型等。
Part II详细介绍了构建系统。在关于NN模块的部分中,它说
简而言之,让我们来谈谈build_deps命令的最后一部分:generate_nn_wrappers()。我们使用PyTorch的自定义cwrap工具绑定到后端库,我们在之前的帖子中提到过。对于绑定TH和THC,我们手动为每个函数编写YAML声明。但是,由于THNN和THCUNN库的相对简单性,我们自动生成cwrap声明和生成的C ++代码。
我们将THNN.h和THCUNN.h头文件复制到torch / lib的原因是,这是generate_nn_wrappers()代码期望找到这些文件的地方。 generate_nn_wrappers()做了一些事情:
- 解析头文件,生成cwrap YAML声明并将其写入输出.cwrap文件
- 使用这些.cwrap文件中的相应插件调用cwrap以生成每个
的源代码- 第二次解析标题以生成THNN_generic.h - 一个采用THPP张量的文库,PyTorch的“通用”C ++张量库,并根据Tensor的动态类型调用适当的THNN / THCUNN库函数
如果没有上下文,也许没有那么有用,但我认为我不应该在这里复制整篇文章。
当我试图在没有阅读这些帖子的情况下追踪NLLLoss的定义时,我最终在aten/src/THNN/generic/ClassNLLCriterion.c,通过aten/src/ATen/nn.yaml。后者可能是第二篇文章谈到的YAML,但我还没有检查过。
答案 1 :(得分:0)
TL; DR:如果你想了解一个函数是如何用 C 代码实现的,你可以在这里查看 Pytorch 的 github 仓库: