对于Scala中这种特殊的varargs案例,这是一种更好的方法

时间:2014-06-12 19:05:56

标签: scala variadic-functions

我在Scala上有以下定义

     def seq(stms: Stm*): Stm = if (stms.isEmpty) EXP(CONST(0)) else stms reduce SEQ

我用它来写stms:

  ESEQ(
    seq(MOVE(TEMP(r), CONST(1)),
      genstm(t, f),
      LABEL(f),
      MOVE(TEMP(r), CONST(0)),
      LABEL(t)),
    TEMP(r))

但在某些情况下我需要“序列的最后一个元素作为列表”所以,我写道:

ESEQ (seq(
  EXP(extenalCall("_newRecord", expressions.length)) ::
  MOVE(rt, Frame.RV) ::
  values:_*
  ), rt)

我希望在调用“seq”方面更加同质化。我做了方法重载:

  def seq(stms: List[Stm]): Stm = if (stms.isEmpty) EXP(CONST(0)) else stms reduce SEQ

  def seq(s1:Stm, l:List[Stm]) = seq(s1 :: l)
  def seq(s1:Stm, s2:Stm, l:List[Stm]) = seq(s1 :: s2 :: l)
  def seq(s1:Stm, s2:Stm, s3:Stm, l:List[Stm]) = seq(s1 :: s2 :: s3 :: l)
  def seq(s1:Stm, s2:Stm, s3:Stm, s4:Stm, l:List[Stm]) = seq(s1 :: s2 :: s3 :: s4 ::  l)
  def seq(s1:Stm, s2:Stm, s3:Stm, s4:Stm, s5:Stm, l:List[Stm]) = seq(s1 :: s2 :: s3 :: s4 :: s5 :: l)

  def seq(s:Stm*) = seq(s.toList)

为了写下这样的最后一个片段:

ESEQ (seq(
  EXP(extenalCall("_newRecord", expressions.length)),
  MOVE(rt, Frame.RV),
  values
  ), rt)

这有不同的方法吗?

2 个答案:

答案 0 :(得分:3)

我能想到的最好的想法是定义从StmList[Stm]的隐式转换,然后接受任意数量的List[Stm]

implicit def stm2list(stm: Stm): List[Stm] = List(stm)

def seq(stms: List[Stm]*): Stm = {
  val flat = stms.flatten
  if (flat.isEmpty) EXP(CONST(0)) else stms reduce SEQ
}

但是它对于List[Stm]的任何第一个元素也是有效的,而不仅仅是最后一个。

它还引入了危险的隐式转换。您可以使用一个将List[Stm]包装在另一个类型中的小类保护它,然后提供2个隐式转换:

class ListOfStmOrStm(val stms: List[Stm])
object ListOfStmOrStm {
  implicit def fromStm(stm: Stm): ListOfStmOrStm = new ListOfStmOrStm(List(stm))
  implicit def fromList(stms: List[Stm]): ListOfStmOrStm = new ListOfStmOrStm(stms)
}

def seq(stms: ListOfStmOrStm*): Stm = {
  val flat = stms.flatMap(_.stms)
  if (flat.isEmpty) EXP(CONST(0)) else stms reduce SEQ
}

注意:将ListOfStmOrStm设为值类是没用的,因为必须将其装箱以存储在Seq参数附带的*中。

答案 1 :(得分:0)

您可以使用Seq()++继续使用逗号:

ESEQ(seq(Seq(
    x,
    y,
    z) ++ values: _*),
  ...)

或者,如果你能以

的方式写seq()
seq(xs: _*) ++ seq(ys: _*) == seq((xs ++ ys): _*)

然后你可以避免写Seq()

ESEQ(seq(
  x,
  y,
  z) ++ seq(values: _*),
  ...)