这个函数如何:input.nn.MSECriterion_updateOutput(self,input,target)工作(在Lua / Torch中)?

时间:2015-05-22 23:42:37

标签: function lua lua-table torch

我有这个功能:

    function MSECriterion:updateOutput(input, target)
        return input.nn.MSECriterion_updateOutput(self, input, target)
    end

现在,

   input.nn.MSECriterion_updateOutput(self, input, target)

返回一个数字。我不知道它是怎么做到的。我已经在调试器中一步步走了,看起来这只是计算一个没有中间步骤的数字。

 input is a Tensor of size 1 (say, -.234). And the 

 nn.MSECriterion_updateOutput(self, input, target) looks like it is just the function MSECriterion:updateOutput(input, target).

我对如何计算数字感到困惑。

我很困惑为什么甚至允许这样做。参数输入是一个张量,它甚至没有任何名为nn.MSE input.nn.MSECriterion_updateOutput的方法。

1 个答案:

答案 0 :(得分:5)

当您执行require "nn"时,会加载init.lua,而require('libnn')依次执行init.c。这是torch / nn的C扩展名。

如果查看libnn.so,您可以找到luaopen_libnn:这是require MSECriterion时调用的初始化函数。

此功能负责初始化torch / nn的所有部分,包括nn_FloatMSECriterion_init(L)的{​​{1}}和nn_DoubleMSECriterion_init(L)的原生部分。

如果您查看generic/MSECriterion.c,就可以找到通用(即针对floatdouble展开的宏)initialization function

static void nn_(MSECriterion_init)(lua_State *L)
{
  luaT_pushmetatable(L, torch_Tensor);
  luaT_registeratname(L, nn_(MSECriterion__), "nn");
  lua_pop(L,1);
}

此init函数修改任何torch.FloatTensortorch.DoubleTensor的元表,以便在nn键下填充一堆函数(有关详细信息,请参阅Torch7 Lua C API )。这些函数在之前定义:

static const struct luaL_Reg nn_(MSECriterion__) [] = {
  {"MSECriterion_updateOutput", nn_(MSECriterion_updateOutput)},
  {"MSECriterion_updateGradInput", nn_(MSECriterion_updateGradInput)},
  {NULL, NULL}
};

换句话说,任何张量都有这些函数附加,这要归功于它的metatable:

luajit -lnn
> print(torch.Tensor().nn.MSECriterion_updateOutput)
function: 0x40921df8
> print(torch.Tensor().nn.MSECriterion_updateGradInput)
function: 0x40921e20

注意:对于具有C本机实现对应的所有torch / nn模块,此机制是相同的。

所以input.nn.MSECriterion_updateOutput(self, input, target)可以调用static int nn_(MSECriterion_updateOutput)(lua_State *L),因为您可以在generic/MSECriterion.c上看到。

此函数计算输入张量之间的均方误差。