使用模式匹配过滤`for`

时间:2014-02-19 13:15:03

标签: scala iterator pattern-matching

我正在阅读TSV文件并使用类似的东西:

case class Entry(entryType: Int, value: Int)

def filterEntries(): Iterator[Entry] = {
  for {
    line <- scala.io.Source.fromFile("filename").getLines()
  } yield new Entry(line.split("\t").map(x => x.toInt))
}

现在我有兴趣过滤掉entryType设置为0的条目,忽略列数大于或小于2的行(与构造函数不匹配)。我想知道是否有一种惯用的方法来实现这一点可能是在伴侣对象中使用模式匹配和unapply方法。我唯一能想到的就是在生成的迭代器上使用.filter

我也会接受不涉及for循环的解决方案,但会返回Iterator[Entry]。他们的解决方案必须能够容忍格式错误的输入。

4 个答案:

答案 0 :(得分:2)

这更具有现实意义:

package object liner {
  implicit class R(val sc: StringContext) {
    object r {
      def unapplySeq(s: String): Option[Seq[String]] = sc.parts.mkString.r unapplySeq s
    }
  }
}

package liner {

  case class Entry(entryType: Int, value: Int)

  object I {
    def unapply(s: String): Option[Int] = util.Try(s.toInt).toOption
  }

  object Test extends App {
    def lines = List("1 2", "3", "", "  4  5  ", "junk", "0, 100000", "6 7 8")

    def entries = lines flatMap {
      case r"""\s*${I(i)}(\d+)\s+${I(j)}(\d+)\s*""" if i != 0 => Some(Entry(i, j))
      case __________________________________________________ => None
    }
    Console println entries
  }
}

希望正则表达式插值器能够很快进入标准发行版,但这表明它的装配是多么容易。同样希望,scanf样式的插值器可以使用case f"$i%d"轻松提取。

我刚开始在模式中使用“细长通配符”来对齐箭头。

有一个蛹或幼虫正则表达式宏:

https://github.com/som-snytt/regextractor

答案 1 :(得分:0)

您可以在for-comprehension的头部创建变量,然后使用guard:

编辑:确保数组的长度

for {
  line <- scala.io.Source.fromFile("filename").getLines()
  arr = line.split("\t").map(x => x.toInt)
  if arr.size == 2 && arr(0) != 0
} yield new Entry(arr(0), arr(1))

答案 2 :(得分:0)

for循环中<-=的左侧可能是完全成熟的模式。所以你可以这样写:

def filterEntries(): Iterator[Int] = for {
  line <- scala.io.Source.fromFile("filename").getLines()
  arr = line.split("\t").map(x => x.toInt)
  if arr.size == 2
  // now you may use pattern matching to extract the array
  Array(entryType, value) = arr
  if entryType == 0
} yield Entry(entryType, value)

请注意,如果某个字段无法转换为Int,则此解决方案将抛出NumberFormatException。如果您不希望这样,那么您必须再次使用x.toInt封装Try并再次进行模式匹配。

答案 3 :(得分:0)

我使用以下代码解决了它:

import scala.util.{Try, Success}

val lines = List(
  "1\t2",
  "1\t",
  "2",
  "hello",
  "1\t3"
)

case class Entry(val entryType: Int, val value: Int)
object Entry {
  def unapply(line: String) = {
    line.split("\t").map(x => Try(x.toInt)) match {
      case Array(Success(entryType: Int), Success(value: Int)) => Some(Entry(entryType, value))
      case _ =>
        println("Malformed line: " + line)
        None
    }
  }
}

for {
  line <- lines
  entryOption = Entry.unapply(line)
  if entryOption.isDefined
} yield entryOption.get