使用Scala的Spark:按另一个列表中包含的值过滤RDD

时间:2016-06-11 12:14:07

标签: scala apache-spark

我将如何按list.contains()过滤? 这是我当前的代码,我有一个Main类,它从命令行参数获取输入,并根据该输入执行相应的调度程序。在这种情况下,它的RecommendationDispatcher类在构造函数中完成所有它的魔术 - 训练模型并为输入的各种用户生成推荐:

import org.apache.commons.lang.StringUtils.indexOfAny
import java.io.{BufferedWriter, File, FileWriter}
import java.text.DecimalFormat
import Util.javaHash
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.Rating


import org.apache.spark.{SparkConf, SparkContext}

class RecommendDispatcher(master:String, inputFile:String, outputFile:String, userList: List[String]) extends java.io.Serializable {

  val format : DecimalFormat = new DecimalFormat("0.#####");
  val file = new File(outputFile)
  val bw = new BufferedWriter(new FileWriter(file))

  val conf = new SparkConf().setAppName("Movies").setMaster(master)
  val sparkContext = new SparkContext(conf)
  val sqlContext = new org.apache.spark.sql.SQLContext(sparkContext)
  val baseRdd = sparkContext.textFile(inputFile)




  val movieIds = baseRdd.map(line => line.split("\\s+")(1)).distinct().map(id => (javaHash(id), id))

  val userIds = baseRdd.map(line => line.split("\\s+")(3)).distinct()
                                        .filter(x => userList.contains(x))
                                        .map(id => (javaHash(id), id))


  val ratings = baseRdd.map(line => line.split("\\s+"))
    .map(tokens => (tokens(3),tokens(1), tokens(tokens.indexOf("review/score:")+1).toDouble))
      .map( x => Rating(javaHash(x._1),javaHash(x._2),x._3))

  // Build the recommendation model using ALS
  val rank = 10
  val numIterations = 10
  val model = ALS.train(ratings, rank, numIterations, 0.01)

  val users = userIds.collect()
  var mids = movieIds.collect()

    usrs.foreach(u => {
      bw.write("Recommendations for " + u + ":\n")
      var ranked = List[(Double, Int)]()
      mids.foreach(x => {
        val movieId = x._1
        val prediction = (model.predict(u._1, movieId), movieId)
        ranked = ranked :+ prediction
      })
      //Sort in descending order
      ranked = ranked.sortBy(x => -1 * x._1)
      ranked.foreach(x => bw.write(x._1 + " ; " + x._2 + "\n"))
    })

  bw.close()

}

这个异常会被放在" .filter" line:

  

线程中的异常" main" org.apache.spark.SparkException:任务没有   序列化

3 个答案:

答案 0 :(得分:2)

我认为一个好方法是将userList转换为broadcast variable

val broadcastUserList= sc.broadcast(userList)
val userIds = baseRdd.map(line => line.split("\\s+")(3)).distinct()
                                      .filter(x => broadcastUserList.value.contains(x))
                                      .map(id => (javaHash(id), id))

答案 1 :(得分:0)

我猜Sim是正确的关闭“泄漏”,你提供的示例代码过于简单。

如果你的主要是这样的:

object test
{
  def main(args: Array[String]): Unit = 
  {
    val sc = ...
    val rdd1 = ...
    val userList = ...
    val rdd2 = rdd1.filter { list.contains( _ ) }
  } 
}

然后没有发生序列化错误。 “userList”,可序列化,没有问题可以序列化到执行程序......

当您开始将“大”主文件建模为单独的类时,问题就开始了。

以下是事情可能出错的一个例子:

class FilterLogic
{
  val userList = List( 1 )  
  def filterRDD( rdd : RDD[ Int ] ) : RDD[ Int ] = 
  {
    rdd.filter { list.contains( _ ) }
  }
}

object Test 
{
  def main(args: Array[String]): Unit = 
  {
    val sc = ...
    val rdd1 = ...
    val rdd2 = new FilterLogic().filterRDD( rdd1 )// This will result in a serialization error!!!
  }
}

既然userList是Logic类的值,当它需要被序列化到执行器时,它还要求整个包装逻辑类被序列化(为什么?因为在Scala中userList实际上是逻辑中的一个getter ())。

解决此问题的几种方法:

1)userList可以在filterRDD函数内创建,然后它不是Logic的val(工作但限制代码共享/建模)

1.1)类似的想法是在filterRDD函数中使用temp val,如下所示:

val list_ = list ; rdd.filter { list_.contains( _ ) }

有效,但很难看,这几乎是痛苦的......

2)可以将逻辑类设置为Serializable(有时可能无法使其序列化)

最后,使用广播可能有(或没有)它的好处,但它与序列化错误无关。

答案 2 :(得分:0)

我尝试序列化了RecommendDispatcher类,但仍然遇到了相同的异常。所以我决定把代码放在Main类中,这解决了我的问题。