如何为任意嵌套长度的嵌套数组指定返回类型?

时间:2018-05-04 22:30:34

标签: arrays scala return-type

假设,作为一个例子,我想要一个递归地将数组包装在另一个数组中的函数,n次。

换句话说,期望的结果是:

wrap(Array(1,2,3), 2) = Array(Array(Array(1,2,3)))
wrap(Array(4,5,6), 3) = Array(Array(Array(Array(4,5,6))))

如何指定返回类型?这取决于n。假设输入的类型为Array[A]

对于n=1,它是Array[Array[A]]

对于n=3,它是Array[Array[Array[Array[A]]]]

我们可以使用Array[_]之类的:

def wrap[A:ClassTag](x:Array[A], n:Int):Array[_] = { 
  if (n == 1) {
    Array(x)
  } else {
    wrap(Array(x), n-1)
  }
}

但编译器并不知道元素是Array s:

> val y = wrap(Array(1,2,3), 1)
  Array[_] = Array(Array(1, 2, 3))
> y(0).length
error: value length is not a member of _$1
  y(0).length
       ^

我们可以使用asInstanceOf,但这似乎不是一个很好的解决方案:

> y(0).asInstanceOf[Array[Int]].length
  Int = 3

3 个答案:

答案 0 :(得分:1)

Array[_]是此类方法的正确类型,但并非所有类型信息都已丢失。您可以使用模式匹配来检索它。

def unwrap(a :Array[_]) :String = a match {
  case Array(sa :Array[_]) => unwrap(sa)
  case ia :Array[Int]      => ia.mkString("+")
  case x                   => x.mkString("-")
}

话虽如此,最好完全避免任意嵌套类型。

答案 1 :(得分:1)

我认为不可能有完美的解决方案,因为n的值是在运行时确定的,但返回类型需要在编译时存在。如果n不是文字的,就像在您的示例中那样,那么您可以做的最好的事情就是返回Array[_]

但是如果你总是要使用文字,那么你基本上可以在编译时将n作为类型参数传递。您传递n=1

,而不是传递A=Array[Array[Int]]
import scala.reflect.ClassTag

trait Wrapper[A, B] {
  def wrap(xs: Array[B]): A
}

implicit def wrapperBase[B] = new Wrapper[Array[B], B] {
  def wrap(xs: Array[B]) = xs
}

implicit def wrapperRec[A : ClassTag, B](implicit w: Wrapper[A, B]) = new Wrapper[Array[A], B] {
  def wrap(xs: Array[B]): Array[A] = Array(w.wrap(xs))
}

def wrap[B, A](xs: Array[B])(implicit w: Wrapper[A, B]): A = w.wrap(xs)

val xs = Array(1, 2, 3)
wrap[Int, Array[Array[Int]]](xs)  // instead of wrap(xs, 1)
wrap[Int, Array[Array[Array[Int]]]](xs)  // instead of wrap(xs, 2)

如果你想得到真正的想象,你可以进入编译时整数类型ala shapeless的Nat类,理论上做wrap[Int, _5]这样的事情,但这肯定是一个更大的兔子洞几乎没什么好处。

答案 2 :(得分:1)

以简单的方式做到这一点是不可能的。类型是编译时公民,数字在运行时存在。考虑如果从用户输入中读取数字n会发生什么。对于不同的将来用户输入,编译器应为该方法生成不同的结果类型。

如果我没有错,那么我们需要一种比Scala更好地支持依赖类型的语言。请参阅此问题:Any reason why scala does not explicitly support dependent types?,尤其是the answer by P. Frolov

也就是说,如果在编译时知道数字n,则可以表达该类型。例如,它是Int文字,final val或文字和final val的简单算术表达式。例如,在final val a = 3; wrap(Array(1,2,3), a * 2 + 1)

的情况下

以下是类型类的示例代码,它实现了这种包装。它使用shapeless库将数字文字很好地转换为Nat类型值:

import scala.reflect.{classTag, ClassTag}

abstract class Wrapper[T : ClassTag, N <: Nat] {
  // Type of Array[T] wrapped N times
  type Out 

  // ClassTag of the array wrapped N times. 
  // It's needed to be able to wrap it one more time.
  def outTag: ClassTag[Out]

  // The actual function that wraps the array
  def apply(array: Array[T]): Out 
}

object Wrapper {
  type Aux[T, N <: Nat, O] = Wrapper[T, N] { type Out = O }

  // Wrap the array 0 times. The base of the recursion.
  implicit def zero[T : ClassTag]: Aux[T, Nat._0, Array[T]] = new Wrapper[T, Nat._0] {
    type Out = Array[T]
    def outTag = classTag[T].wrap
    def apply(array: Array[T]): Out = array
  }

  // Given a Wrapper, that wraps the array N times,
  //   make a Wrapper, that wraps N + 1 times.
  implicit def next[T : ClassTag, N <: Nat](
    implicit prev: Wrapper[T, N]
  ): Aux[T, Succ[N], Array[prev.Out]] = new Wrapper[T, Succ[N]] {
    type Out = Array[prev.Out]
    def outTag = prev.outTag.wrap
    def apply(array: Array[T]): Out = Array(prev(array))(prev.outTag)
  }
}

使用此类型类的wrap函数:

def wrap[A: ClassTag](
  x: Array[A], 
  n: Nat
)(
  implicit wrapper: Wrapper[A, n.N]
): wrapper.Out = 
  wrapper(x)

编译器知道结果的类型,并且可以在没有任何类型转换的情况下使用结果:

scala> val a = wrap(Array(1,2,3), 3)
a: Array[Array[Array[Array[Int]]]] = Array(Array(Array(Array(1, 2, 3))))

scala> a.head.head.head.sum
res1: Int = 6

scala> object Foo {
  final val n = 2
  def run() = wrap(Array(1,2,3), n * 2 + 1)
} 
defined object Foo

scala> Foo.run()
res2: Array[Array[Array[Array[Array[Array[Int]]]]]] = Array(Array(Array(Array(Array(Array(1, 2, 3))))))