Scala中的交叉产品

时间:2013-02-06 22:35:40

标签: scala functional-programming cartesian-product cross-product

我想要一个二进制运算符cross(交叉产品/笛卡尔积),它与Scala中的遍历一起运行:

val x = Seq(1, 2)
val y = List('hello', 'world', 'bye')
val z = x cross y    # i can chain as many traversables e.g. x cross y cross w etc

assert z == ((1, 'hello'), (1, 'world'), (1, 'bye'), (2, 'hello'), (2, 'world'), (2, 'bye'))

仅在Scala中执行此操作的最佳方法是什么(即不使用scalaz之类的东西)?

7 个答案:

答案 0 :(得分:74)

你可以通过隐式类和Scala 2.10中的for - 理解来直截了当地做到这一点:

implicit class Crossable[X](xs: Traversable[X]) {
  def cross[Y](ys: Traversable[Y]) = for { x <- xs; y <- ys } yield (x, y)
}

val xs = Seq(1, 2)
val ys = List("hello", "world", "bye")

现在:

scala> xs cross ys
res0: Traversable[(Int, String)] = List((1,hello), (1,world), ...

这在2.10之前是可能的 - 只是不那么简洁,因为你需要定义类和隐式转换方法。

你也可以这样写:

scala> xs cross ys cross List('a, 'b)
res2: Traversable[((Int, String), Symbol)] = List(((1,hello),'a), ...

但是,如果您希望xs cross ys cross zs返回Tuple3,则需要大量样板文件或Shapeless等库。

答案 1 :(得分:15)

x_listy_list与:

交叉
val cross = x_list.flatMap(x => y_list.map(y => (x, y)))

答案 2 :(得分:10)

以下是任意数量列表的递归交叉积的实现:

def crossJoin[T](list: Traversable[Traversable[T]]): Traversable[Traversable[T]] =
  list match {
    case xs :: Nil => xs map (Traversable(_))
    case x :: xs => for {
      i <- x
      j <- crossJoin(xs)
    } yield Traversable(i) ++ j
  }

crossJoin(
  List(
    List(3, "b"),
    List(1, 8),
    List(0, "f", 4.3)
  )
)

res0: Traversable[Traversable[Any]] = List(List(3, 1, 0), List(3, 1, f), List(3, 1, 4.3), List(3, 8, 0), List(3, 8, f), List(3, 8, 4.3), List(b, 1, 0), List(b, 1, f), List(b, 1, 4.3), List(b, 8, 0), List(b, 8, f), List(b, 8, 4.3))

答案 3 :(得分:4)

猫用户的替代选择:

sequence上的

List[List[A]]创建叉积:

import cats.implicits._

val xs = List(1, 2)
val ys = List("hello", "world", "bye")

List(xs, ys).sequence 
//List(List(1, hello), List(1, world), List(1, bye), List(2, hello), List(2, world), List(2, bye))

答案 4 :(得分:2)

MPSImageLanczosScale

答案 5 :(得分:1)

这类似于Milad's response,但非递归。

def cartesianProduct[T](seqs: Seq[Seq[T]]): Seq[Seq[T]] = {
  seqs.foldLeft(Seq(Seq.empty[T]))((b, a) => b.flatMap(i => a.map(j => i ++ Seq(j))))
}

基于this blog post

答案 6 :(得分:0)

类似于其他答复,只是我的方法。

def loop(lst: List[List[Int]],acc:List[Int]): List[List[Int]] = {
  lst match {
    case head :: Nil => head.map(_ :: acc)
    case head :: tail => head.flatMap(x => loop(tail,x :: acc))
    case Nil => ???
  }
}
val l1 = List(10,20,30,40)
val l2 = List(2,4,6)
val l3 = List(3,5,7,9,11)

val lst = List(l1,l2,l3)

loop(lst,List.empty[Int])
相关问题