根据列标识

时间:2018-02-22 15:27:29

标签: scala apache-spark

我有两个数据帧。 第一个看起来像这样(通道的数量将根据类型而变化)此数据帧存储设备的类型和每个通道的值。

+-----+----------+----------+
| Type|X_ChannelA|Y_ChannelB|
+-----+----------+----------+
|TypeA|        11|        20|
+-----+----------+----------+

第二个数据帧是从csv导入的,由我生成。 现在我有这种格式(可以改为需要的任何东西)

+-----+--------------+--------------+--------------+--------------+
| Type|X_ChannelA_min|X_ChannelA_max|Y_ChannelB_min|Y_ChannelB_max|
+-----+--------------+--------------+--------------+--------------+
|TypeA|             8|            12|             9|            13|
+-----+--------------+--------------+--------------+--------------+

现在我想将实际的Channel值与min和max值进行比较,并创建一个带有_status的新列,如果值介于min和max之间,则包含一个值,如果超过min或max,则为零。

这个例子的结果

+-----+----------+----------+-----------------+-----------------+
| Type|X_ChannelA|Y_ChannelB|X_ChannelA_status|Y_ChannelB_status|
+-----+----------+----------+-----------------+-----------------+
|TypeA|        11|        20|                1|                0|
+-----+----------+----------+-----------------+-----------------+

代码在这里:

    val df_orig = spark.sparkContext.parallelize(Seq(
      ("TypeA", 11, 20)
    )).toDF("Type", "X_ChannelA", "Y_ChannelB")

    val df_def = spark.sparkContext.parallelize(Seq(
      ("TypeA", 8, 12, 9, 13)
    )).toDF("Type", "X_ChannelA_min", "X_ChannelA_max", "Y_ChannelB_min", "Y_ChannelB_max")

我已经尝试了一些不同的事情,但事实并非如此 就像通过获取所有通道的字符串数组然后使用

创建collumsn来创建列
val pattern = """[XYZP]_Channel.*"""
val fieldNames = df_orig.schema.fieldNames.filter(_.matches(pattern))
fieldNames.foreach(x => df.withColumn(s"${x}_status", <compare logic comes here>)

我的下一个想法是将df_orig与df_def连接,然后将channel_value,channel_min,channel_max与concat_ws一起添加到单个列中,将这些值与比较逻辑进行比较并将结果写入列

+-----+----------+----------+----------------+----------------+-------------+...
| Type|X_ChannelA|Y_ChannelB|X_ChannelA_array|Y_ChannelB_array|X_ChannelA_st|
+-----+----------+----------+----------------+----------------+-------------+...
|TypeA|        11|        20|     [11, 8, 12]|     [20, 9, 13]|            1|
+-----+----------+----------+----------------+----------------+-------------+...

如果有一个更简单的解决方案,那么推进正确的方向会很好。

编辑:如果我的描述基本上不清楚我正在寻找的是: 我正在寻找的是

foreach channel in channellist (
    ds.withColumn("<channel>_status", when($"<channel>" < $"<channel>_min" || $"<channel>" > $"<channel>_max"), 1).otherwise 0)
)

编辑:我找到了一个解决方案:

val df_joined = df_orig.join(df_def, Seq("Type"))
val pattern = """[XYZP]_Channel.*"""
val fieldNames = df_orig.schema.fieldNames.filter(_.matches(pattern))
val df_newnew = df_joined.select(col("*") +: (fieldNames.map(c => when(col(c) <= col(c+"_min") || col(c) >= col(c+"_max"), 1).otherwise(0).as(c+"_status))): _*)

1 个答案:

答案 0 :(得分:1)

join是要走的路。您必须正确使用when功能,如下所示

import org.apache.spark.sql.functions._
df_orig.join(df_def, Seq("Type"), "left")
  .withColumn("X_ChannelA_status", when(col("X_ChannelA") >= col("X_ChannelA_min") && col("X_ChannelA") <= col("X_ChannelA_max"), 1).otherwise(0))
  .withColumn("Y_ChannelB_status", when(col("Y_ChannelB") >= col("Y_ChannelB_min") && col("Y_ChannelB") <= col("Y_ChannelB_max"), 1).otherwise(0))
  .select("Type", "X_ChannelA", "Y_ChannelB", "X_ChannelA_status", "Y_ChannelB_status")

你应该得到你想要的输出

+-----+----------+----------+-----------------+-----------------+
|Type |X_ChannelA|Y_ChannelB|X_ChannelA_status|Y_ChannelB_status|
+-----+----------+----------+-----------------+-----------------+
|TypeA|11        |20        |1                |0                |
+-----+----------+----------+-----------------+-----------------+

<强>更新

如果您的频道数据框中有更多列,如果您不想如上所述对所有列进行硬编码,那么您可以从foldLeft中受益scala中的函数)

但在此之前,您必须决定要迭代的列(即频道)

val df_orig_Columns = df_orig.columns
val columnsToIterate = df_orig_Columns.toSet - "Type"

然后在join之后,使用foldLeft来概括上述withColumn流程

val joinedDF = df_orig.join(df_def, Seq("Type"), "left")

import org.apache.spark.sql.functions._
val finalDF = columnsToIterate.foldLeft(joinedDF){(tempDF, colName) => tempDF.withColumn(colName+"_status", when(col(colName) >= col(colName+"_min") && col(colName) <= col(colName+"_max"), 1).otherwise(0))}

最后,您select必要的列

val finalDFcolumns = df_orig_Columns ++ columnsToIterate.map(_+"_status")
finalDF.select(finalDFcolumns.map(col): _*)

我猜是的。希望它不仅仅是有用的