Spark:如何连接两个`数据集A和B,条件是A的ID数组列不包含B的ID列?

时间:2018-02-06 05:33:23

标签: scala apache-spark join apache-spark-dataset

我的问题不是[Joining Spark Dataframes with "isin" operator的重复。我的问题是“不在”,而不是“在”。它是不同的!

我有两个Dataset s:

  • userProfileDatasetDataset[UserProfile]
  • jobModelsDatasetDataset[JobModel]

案例clss UserProfile定义为

case class UserProfile(userId: Int, visitedJobIds: Array[Int])

和案例类JobModel定义为

case class JobModel(JobId: Int, Model: Map[String, Double])

我还制作了两个对象(UserProfileFieldNamesJobModelFieldNames),其中包含这两个案例类的字段名称。

我的目标是,对于userProfileDataset中的每位用户,找到JobModel.JobId 中未包含的UserProfile.visitedJobIds怎么做?

我考虑使用crossJoin然后使用filter。它可能会奏效。有更直接或有效的方法吗?

我尝试了以下方法,但没有一种方法有效:

val result = userProfileDataset.joinWith(jobModelsDataset,
      !userProfileDataset.col(UserProfileFieldNames.visitedJobIds).contains(jobModelsDataset.col(JobModelFieldNames.jobId)),
      "left_outer"
    )

导致:

  

线程“main”中的异常org.apache.spark.sql.AnalysisException:   无法解析'包含(_1visitedJobIds,CAST(_2JobId AS   STRING))'由于数据类型不匹配:参数1需要字符串类型,   但是,'_1visitedJobIds'是数组类型。;;

可能是因为contains方法只能用于测试一个字符串是否包含另一个字符串吗?

以下情况也不起作用:

!jobModelsDataset.col(JobModelFieldNames.jobId)
  .isin(userProfileDataset.col(UserProfileFieldNames.visitedJobIds))

导致:

  

线程“main”中的异常org.apache.spark.sql.AnalysisException:   由于数据无法解析'(_2JobId IN(_1visitedJobIds))'   类型不匹配:参数必须是相同的类型,但是:IntegerType!=   数组类型(IntegerType,假);; '加入LeftOuter,不是_2#74.JobId IN   (_1#73.visitedJobIds)

3 个答案:

答案 0 :(得分:1)

如果唯一作业ID的数量不是太多,那么您可以按如下方式收集和广播

val jobIds = jobModelsDataset.map(_.JobId).distinct.collect().toSeq
val broadcastedJobIds = spark.sparkContext.broadcast(jobIds)

要将此广播序列与visitedJobIds列进行比较,您可以创建UDF

val notVisited = udf((visitedJobs: Seq[Int]) => { 
  broadcastedJobIds.value.filterNot(visitedJobs.toSet)
})

val df = userProfileDataset.withColumn("jobsToDo", notVisited($"visitedJobIds"))

使用jobIds = 1,2,3,4,5和示例数据框进行测试

+------+---------------+
|userId|  visitedJobIds|
+------+---------------+
|     1|      [1, 2, 3]|
|     2|      [3, 4, 5]|
|     3|[1, 2, 3, 4, 5]|
+------+---------------+

将给出最终的数据框

+------+---------------+--------+
|userId|  visitedJobIds|jobsToDo|
+------+---------------+--------+
|     1|      [1, 2, 3]|  [4, 5]|
|     2|      [3, 4, 5]|  [1, 2]|
|     3|[1, 2, 3, 4, 5]|      []|
+------+---------------+--------+

答案 1 :(得分:1)

explode只需userProfileDataset 数组cast IntegerType join jobModelsDataset JobId collect_list {}} {em>列已经是 IntegerType 。然后最后使用import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ val temp = userProfileDataset.withColumn("visitedJobIds", explode(col("visitedJobIds"))) .withColumn("visitedJobIds", col("visitedJobIds").cast(IntegerType)) 内置函数来获得最终结果。

爆炸投射将如下所示

temp.join(jobModelsDataset, temp("visitedJobIds") === jobModelsDataset("JobId"), "left")
      .groupBy("userId")
      .agg(collect_list("visitedJobIds").as("visitedJobIds"), collect_list("JobId").as("ModelJobIds"))
    .show(false)

加入收集将如下所示

JobIds

你应该得到你想要的东西

<强>更新

如果您正在寻找与userId无关联的val list = jobModelsDataset.select(collect_list("JobId")).rdd.first()(0).asInstanceOf[collection.mutable.WrappedArray[Int]] def notContained = udf((array: collection.mutable.WrappedArray[Int]) => list.filter(x => !(array.contains(x)))) temp.join(jobModelsDataset, temp("visitedJobIds") === jobModelsDataset("JobId"), "left") .groupBy("userId") .agg(collect_list("visitedJobIds").as("visitedJobIds"), collect_list("JobId").as("ModelJobIds")) .withColumn("ModelJobIds", notContained(col("ModelJobIds"))) .show(false) ,那么您可以执行以下操作。

broadcasting

您可以通过# src/game_types.nim - where I am storing type definitions SceneLifeCycleProc* = proc() SceneObject* = ref object tags*: seq[string] active*: bool visible*: bool x, y: float Scene* = ref object name*: string sceneObjects*: seq[SceneObject] onRegister*: SceneLifeCycleProc onEnter*: SceneLifeCycleProc onUpdate*: SceneLifeCycleProc onRender*: SceneLifeCycleProc onExit*: SceneLifeCycleProc onDestroy*: SceneLifeCycleProc # src/scene_managment.nim - file exporting newScene from include game_types proc newScene* ( name: string, sceneObjects: seq[SceneObject], slc: SceneLifeCycle): Scene = new result result.name = name result.sceneObjects = sceneObjects result.onRegister = slc[0] result.onEnter = slc[1] result.onUpdate = slc[2] result.onRender = slc[3] result.onExit = slc[4] result.onDestroy = slc[5] 改善答案。

答案 2 :(得分:0)

最初我有另一种方法,使用crossJoin然后filter

val result = userProfileDataset
  .crossJoin(jobModelsDataset) // 27353040 rows
  .filter(row => !row(2).asInstanceOf[Seq[Int]].contains(row.getInt(3))) //27352633 rows

如果我使用@ Shaido的方法然后explode,我应该能够获得与此方法相同的结果。然而,即使在我的情况下使用filter,这种方法也非常昂贵(我比较了经过的时间)。 explain方法还可以打印出物理计划。

所以我不会使用crossJoin方法。我只想发布并保留在这里。