Pattern matching on Spark Row values

Asked: 2017-03-06 15:43:29

Tags: scala apache-spark

I am trying to write a function that multiplies all the numbers in a Spark RDD. There is only one column, and it holds numbers. The problem is that I do not know in advance what type the numeric values will be, so I tried this:

    def mult(a: Row, b: Row): Row = {
      (a(0), b(0)) match {
        case (v: java.lang.Double, v1: java.lang.Double) => Row(v.asInstanceOf[Double] * v1.asInstanceOf[Double])
        case (v: java.lang.Long, v1: java.lang.Long) => Row((v.asInstanceOf[Long] * v1.asInstanceOf[Long]).toDouble)
        case (v: java.lang.Float, v1: java.lang.Float) => Row((v.asInstanceOf[Float] * v1.asInstanceOf[Float]).toDouble)
        case (v: java.lang.Integer, v1: java.lang.Integer) => Row((v.asInstanceOf[Int] * v1.asInstanceOf[Int]).toDouble)
        case _ => throw new Exception("Incorrect data type in column")
      }
    }

    val result = df.rdd.reduce((a, b) => mult(a, b))
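This first version matches on the boxed runtime classes of the values, which is the right technique for values pulled out of an untyped container. A minimal, Spark-free sketch of the same idea (`multAny` is an illustrative name, not Spark API; it operates on `Any` the way `a(0)` behaves in a `Row`):

```scala
// Primitives stored behind Any are boxed to java.lang wrappers, so a
// pattern match must test against the boxed runtime classes.
def multAny(a: Any, b: Any): Double = (a, b) match {
  case (v: java.lang.Double, v1: java.lang.Double)   => v * v1
  case (v: java.lang.Long, v1: java.lang.Long)       => (v * v1).toDouble
  case (v: java.lang.Float, v1: java.lang.Float)     => (v * v1).toDouble
  case (v: java.lang.Integer, v1: java.lang.Integer) => (v * v1).toDouble
  case _ => throw new Exception("Incorrect data type in column")
}

println(multAny(2.0, 3.0)) // prints 6.0
println(multAny(4L, 5L))   // prints 20.0
```

Note that the `asInstanceOf` casts in the original are redundant: once a case like `v: java.lang.Double` matches, Scala's predefined implicit conversions unbox `v` for arithmetic.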

And like this:

    def mult(a: Row, b: Row): Row = {
      (a(0), b(0)) match {
        case (v: DoubleType, v1: DoubleType) => Row(v.asInstanceOf[Double] * v1.asInstanceOf[Double])
        case (v: LongType, v1: LongType) => Row((v.asInstanceOf[Long] * v1.asInstanceOf[Long]).toDouble)
        case (v: FloatType, v1: FloatType) => Row((v.asInstanceOf[Float] * v1.asInstanceOf[Float]).toDouble)
        case (v: IntegerType, v1: IntegerType) => Row((v.asInstanceOf[Int] * v1.asInstanceOf[Int]).toDouble)
        case _ => throw new Exception ("Incorrect data type in column")
      }
    }

But at runtime I get:

    java.lang.Exception: Incorrect data type in column

And when I do

    println(v.getClass)

I always get

    class java.lang.Double
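That output is consistent with how values come out of a `Row`: primitives are boxed, so the runtime class of a `Double` column's value is `java.lang.Double`. `DoubleType` and friends (from `org.apache.spark.sql.types`) are schema descriptors, not runtime classes of column values, so a type test like `case v: DoubleType` can never match and the second version always falls through to the default case. A Spark-free sketch of the distinction (`SchemaTypeStandIn` is an assumed placeholder for a schema type, since Spark is not on the classpath here):

```scala
// A value held as Any is a boxed primitive at runtime.
val v: Any = 444.1235
println(v.getClass) // prints: class java.lang.Double

// A schema descriptor is an unrelated class, so a type test against it
// never matches a column value.
class SchemaTypeStandIn
println(v.isInstanceOf[SchemaTypeStandIn]) // prints: false
```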

I tested it with this data:

    val data = List(
      List(444.1235D),
      List(67.5335D),
      List(69.5335D),
      List(677.5335D),
      List(47.5335D)
    )

    val rdd = sparkContext.parallelize(data).map(Row.fromSeq(_))
    val schema = StructType(Array(
      StructField("value", DataTypes.DoubleType, false)
    ))

    val df = sqlContext.createDataFrame(rdd, schema)
    df.createOrReplaceTempView(tableName)

And then the same with Long:

    val data1 = List(
      List(555L),
      List(955L),
      List(575L),
      List(355L),
      List(615L),
      List(0L)
    )

    val rdd1 = sparkContext.parallelize(data1).map(Row.fromSeq(_))
    val schema1 = StructType(Array(
      StructField("value", DataTypes.LongType, false)
    ))

    val df1 = sqlContext.createDataFrame(rdd1, schema1)
    df1.createOrReplaceTempView(tableName1)

What am I doing wrong, and how can I achieve this?

0 Answers

There are no answers yet.