为Torch类设置__index

时间:2016-11-09 20:32:25

标签: lua torch

是否可以为火炬类设置__index方法?我试图实现一个简单的dataset课程,如“深度学习与火炬”中所述:(ipynb here

trainset = {
    inputs = {0, 1, 1, 0},
    targets = {1, 1, 1, 0}
}

index = function(t, i)
    return {t.inputs[i], t.targets[i]}
end

setmetatable(trainset, {
    __index = index
)

允许您执行返回trainset[1]]的{​​{1}}。

但是,将此作为火炬类实现不起作用。

{0, 1}

似乎在创建对象时,local torch = require("torch") do Dataset = torch.class("Dataset") function Dataset:__init(i, t) self.inputs = i self.targets = t end function Dataset.__index(t, v) print("inside index") return { rawget(t, inputs)[v], rawget(t, targets)[v] } end end Dataset({0, 1, 1, 0}, {1, 1, 1, 0}) -- fails 被调用并失败,因为尚未创建__index()index。如果未使用targets,则会导致堆栈溢出。

我对Lua的理解是有限的,但我很惊讶在对象创建过程中看到rawget被调用:我认为幕后有些东西我不完全理解。

1 个答案:

答案 0 :(得分:1)

Torch类都实现__index,它将在metatable中查找__index__,用于重载。

来自docs

  

如果想在元类中提供索引 newindex ,   这些运营商必须遵循特定的方案:

     

index 必须返回值且为true或仅返回false。在第一种情况下,它意味着索引能够处理给定的   参数(例如,类型是正确的)。第二种情况意味着它   无法做任何事情,所以在根元表中的__index可以   尝试查看元类是否包含所需的值。

对于该示例,这意味着__index__(不是__index!)方法必须检查是否type(v) == "number",如果不是,则返回false以便__index可以在对象metatable中查找值。

local torch = require("torch")

do 
    Dataset = torch.class("Dataset")

    function Dataset:__init(i, t)
        self.inputs = i
        self.targets = t
    end

function Dataset.__index__(t, v)
    if type(v) == "number" then
        local tbl =  {
            t.inputs[v],
            t.targets[v]
        }
        return tbl, true
    else
        return false
    end
end

local dset = Dataset({0, 1, 1, 0}, {1, 1, 1, 0})
dset[1] --> {0, 1}