是否可以为火炬类设置__index方法?我试图实现一个简单的dataset
课程,如“深度学习与火炬”中所述:(ipynb here)
trainset = {
inputs = {0, 1, 1, 0},
targets = {1, 1, 1, 0}
}
index = function(t, i)
return {t.inputs[i], t.targets[i]}
end
setmetatable(trainset, {
__index = index
)
允许您执行返回trainset[1]]
的{{1}}。
但是,将此作为火炬类实现不起作用。
{0, 1}
似乎在创建对象时,local torch = require("torch")
do
Dataset = torch.class("Dataset")
function Dataset:__init(i, t)
self.inputs = i
self.targets = t
end
function Dataset.__index(t, v)
print("inside index")
return {
rawget(t, inputs)[v],
rawget(t, targets)[v]
}
end
end
Dataset({0, 1, 1, 0}, {1, 1, 1, 0}) -- fails
被调用并失败,因为尚未创建__index()
和index
。如果未使用targets
,则会导致堆栈溢出。
我对Lua的理解是有限的,但我很惊讶在对象创建过程中看到rawget
被调用:我认为幕后有些东西我不完全理解。
答案 0 :(得分:1)
Torch类都实现__index
,它将在metatable中查找__index__
,用于重载。
来自docs:
如果想在元类中提供索引或 newindex , 这些运营商必须遵循特定的方案:
index 必须返回值且为true或仅返回false。在第一种情况下,它意味着索引能够处理给定的 参数(例如,类型是正确的)。第二种情况意味着它 无法做任何事情,所以在根元表中的__index可以 尝试查看元类是否包含所需的值。
对于该示例,这意味着__index__
(不是__index
!)方法必须检查是否type(v) == "number"
,如果不是,则返回false
以便__index
可以在对象metatable中查找值。
local torch = require("torch")
do
Dataset = torch.class("Dataset")
function Dataset:__init(i, t)
self.inputs = i
self.targets = t
end
function Dataset.__index__(t, v)
if type(v) == "number" then
local tbl = {
t.inputs[v],
t.targets[v]
}
return tbl, true
else
return false
end
end
local dset = Dataset({0, 1, 1, 0}, {1, 1, 1, 0})
dset[1] --> {0, 1}