在lua中编写两个函数

时间:2014-11-27 12:27:58

标签: function lua

我刚刚开始学习lua,所以我所要求的可能是不可能的。

现在,我有一个接受函数的方法:

function adjust_focused_window(fn)
  local win = window.focusedwindow()
  local winframe = win:frame()
  local screenrect = win:screen():frame()
  local f, s = fn(winframe, screenrect)
  win:setframe(f)
end

我有几个函数接受这些框架和矩形(只显示一个):

function full_height(winframe, screenrect)
   print ("called full_height for " .. tostring(winframe))
  local f = {
     x = winframe.x,
     y = screenrect.y,
     w = winframe.w,
     h = screenrect.h,
  }
  return f, screenrect
end

然后,我可以做以下事情:

hotkey.bind(scmdalt, '-', function() adjust_focused_window(full_width) end)

现在,如何在不改变它的定义的情况下将几个函数组合到adjust_focused_window。类似的东西:

hotkey.bind(scmdalt, '=', function() adjust_focused_window(compose(full_width, full_height)) end)

其中compose2将返回一个接受与full_widthfull_height相同的参数的函数,并在内部执行以下操作:

full_height(full_width(...))

2 个答案:

答案 0 :(得分:3)

正如评论中所提到的,要将两个函数链接在一起,您可以这样做:

function compose(f1, f2)
  return function(...) return f1(f2(...)) end
end

但是,如果您想将两个以上的功能连接在一起怎么办?你可能会问,是否有可能“组合”任意数量的函数?

答案肯定是肯定的 - 下面我将展示3种不同的实施方法,并快速总结其后果。

迭代表方法

这里的想法是依次调用列表中的每个函数。执行此操作时,将上一次调用的返回结果保存到表中,然后解压缩该表并将其传递给下一个调用。

function compose1(...)
    local fnchain = check_functions {...}
    return function(...)
        local args = {...}
        for _, fn in ipairs(fnchain) do
            args = {fn(unpack(args))}
        end
        return unpack(args)
    end
end

上面的check_functions助手只检查传入的内容确实是函数 - 如果没有则引发错误。为简洁起见,省略了实施。

+ :合理直接的方法。可能是你第一次尝试时想出来的。

- :资源效率不高。很多垃圾表来存储调用之间的结果。您还必须处理包装和拆包结果。

Y-Combinator Pattern

这里的关键见解是,即使我们调用的函数不是递归的,也可以通过在递归函数上捎带它来使其递归。

function compose2(...)
  local fnchain = check_functions {...}
  local function recurse(i, ...)
    if i == #fnchain then return fnchain[i](...) end
    return recurse(i + 1, fnchain[i](...))
  end
  return function(...) return recurse(1, ...) end
end

+ :不会像上面那样创建额外的临时表。仔细编写为tail-recursive - 这意味着调用long函数链不需要额外的堆栈空间。它有一定的优雅。

元脚本生成

使用最后一种方法,您可以使用lua函数实际生成执行所需函数调用链的精确lua代码。

function compose3(...)
    local luacode = 
    [[
        return function(%s)
            return function(...)
                return %s
            end
        end
    ]]
    local paramtable = {}
    local fcount = select('#', ...)
    for i = 1, fcount do
        table.insert(paramtable, "P" .. i)
    end
    local paramcode = table.concat(paramtable, ",")
    local callcode = table.concat(paramtable, "(") ..
                     "(...)" .. string.rep(')', fcount - 1)
    luacode = luacode:format(paramcode, callcode)
    return loadstring(luacode)()(...)
end

loadstring(luacode)()(...)可能需要一些解释。在这里,我选择在生成的脚本中将函数链编码为参数名称(P1, P2, P3等)。额外的()括号用于“展开”嵌套函数,因此最内部的函数是返回的函数。 P1, P2, P3 ... Pn参数成为链中每个函数的上升值,例如

function(...)
  return P1(P2(P3(...)))
end

注意,您也可以使用setfenv完成此操作,但我选择此路由只是为了避免lua 5.1和5.2之间关于如何设置功能环境的重大变化。

+ :避免使用类似方法#2的额外中间表。不滥用堆栈。

- :需要额外的字节码编译步骤。

答案 1 :(得分:0)

您可以遍历传递的函数,使用前一个函数的结果依次调用链中的下一个函数。

function module._compose(...)
  local n = select('#', ...)
  local args = { n = n, ... }
  local currFn = nil

  for _, nextFn in ipairs(args) do
    if type(nextFn) == 'function' then
      if currFn == nil then
        currFn = nextFn
      else
        currFn = (function(prev, next)
          return function(...)
            return next(prev(...))
          end
        end)(currFn, nextFn)
      end
    end
  end

  return currFn
end

注意上面Immediately Invoked Function Expressions的使用,它允许重用的函数变量不会调用无限递归循环,这发生在以下代码中:

function module._compose(...)
  local n = select('#', ...)
  local args = { n = n, ... }
  local currFn = nil

  for _, nextFn in ipairs(args) do
    if type(nextFn) == 'function' then
      if currFn == nil then
        currFn = nextFn
      else
        currFn = function(...)
          return nextFn(currFn(...)) -- this will loop forever, due to closure
        end
      end
    end
  end

  return currFn
end

虽然 Lua 不支持三元运算符,但可以使用 short-circuit evaluation 来移除内部的 if 语句:

function module.compose(...)
  local n = select('#', ...)
  local args = { n = n, ... }
  local currFn = nil

  for _, nextFn in ipairs(args) do
    if type(nextFn) == 'function' then
      currFn = currFn and (function(prev, next)
        return function(...)
          return next(prev(...))
        end
      end)(currFn, nextFn) or nextFn
    end
  end

  return currFn
end