实现火炬的__len__元函数

时间:2016-06-03 16:05:34

标签: lua torch

在我们的torch-dataframe项目中,我们尝试按以下方式实施__len__元函数:

MyClass.__len__ = argcheck{
    {name="self", type="MyClass"},
    {name="other", type="MyClass"},
    call=function(self, other)
    return self.n_rows
end}

这适用于Lua 5.2和5.3但是对于Lua 5.1,luajit 2.0和2.1,返回的变量不是实际的行号而是0.感觉是它返回MyClass的新实例,但很难理解为什么。有关__len更改here的说明,但这是我们迄今为止找到的最佳文档提示。

有点令人惊讶的是需要两个论点。当argcheck提供单个参数版本时:

MyClass.__len__ = argcheck{
   {name = "self", type = "MyClass"},
   call=function(self)
   return self.n_rows
end}
它扔了:

[string "argcheck"]:28: 
Arguments:

({
   self = MyClass  -- 
})


Got: MyClass, MyClass

我们目前依赖argcheck重载运算符来处理:

MyClass.__len__ = argcheck{
    {name="self", type="MyClass"},
    {name="other", type="MyClass"},
    call=function(self, other)
    return self.n_rows
end}

MyClass.__len__ = argcheck{
    overload=MyClass.__len__,
    {name="self", type="MyClass"},
    call=function(self)
    return self.n_rows
end}

有关详细信息,请参阅完整课程和特拉维斯报告:

测试用例

这是一个完整的测试用例,它在5.2和5.3中按预期工作,或许可以更简洁的方式说明完整包的问题:

require 'torch'
local argcheck = require "argcheck"

local MyClass = torch.class("MyClass")

function MyClass:init()
    self.n_rows = 0
end

MyClass.__len__ = argcheck{
    {name = "self", type = "MyClass"},
    {name = "other", type = "MyClass"},
    call=function(self, other)
    print(self.n_rows)
    print(other.n_rows)
    return(self.n_rows)
end}

local obj = MyClass.new()
obj.n_rows = 1
local n = #obj
print(n)

按预期打印:

1
1
1

1 个答案:

答案 0 :(得分:0)

该问题与this SO question有关。在5.1中只有no support

  表中的

__ len计划在5.2中得到支持。见LuaFiveTwo。