I have a dataframe with three columns: id, index and value.
+---+-----+-------------------+
| id|index| value|
+---+-----+-------------------+
| A| 1023|0.09938822262205915|
| A| 1046| 0.3110047630613805|
| A| 1069| 0.8486710971453512|
+---+-----+-------------------+
root
|-- id: string (nullable = true)
|-- index: integer (nullable = false)
|-- value: double (nullable = false)
Then I have another dataframe, dfIntervals, which shows the desired periods for each id:
+---+-----------+---------+
| id|start_index|end_index|
+---+-----------+---------+
| A| 1069| 1276|
| B| 2066| 2291|
| B| 1616| 1841|
| C| 3716| 3932|
+---+-----------+---------+
root
|-- id: string (nullable = true)
|-- start_index: integer (nullable = false)
|-- end_index: integer (nullable = false)
I have the following three templates:
val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)
The goal is, for each row of dfIntervals, to apply a function (let's assume it is a correlation) that receives the value column of dfRaw together with the three template arrays, and adds three columns to dfIntervals, one per template.

Assumptions:
1 - The template arrays have a size of exactly 10.
2 - There are no duplicates in the index column of dfRaw.
3 - The start_index and end_index values of dfIntervals exist in the index column of dfRaw, and there are exactly 10 rows between them. For example, dfRaw.filter($"id" === "A").filter($"index" >= 1069 && $"index" <= 1276).count (the first row of dfIntervals) returns exactly 10.
The result should be three new columns added to the dfIntervals dataframe, named corr_w_template1, corr_w_template2 and corr_w_template3.

Here is the code that generates these dataframes:

import org.apache.spark.sql.functions._
// assuming a SparkSession named spark is in scope (e.g. spark-shell); its implicits are needed for toDF and the $ syntax
import spark.implicits._
val mySeed = 1000
/* Defining templates for correlation analysis*/
val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)
/* Defining raw data*/
var dfRaw = Seq(
("A", (1023 to 1603 by 23).toArray),
("B", (341 to 2300 by 25).toArray),
("C", (2756 to 3954 by 24).toArray)
).toDF("id", "index")
dfRaw = dfRaw.select($"id", explode($"index") as "index").withColumn("value", rand(seed=mySeed))
/* Defining intervals*/
var dfIntervals = Seq(
("A", 1069, 1276),
("B", 2066, 2291),
("B", 1616, 1841),
("C", 3716, 3932)
).toDF("id", "start_index", "end_index")
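As a quick sanity check (my addition, not part of the original question), assumption 3 can be verified with the dfRaw and dfIntervals defined above by joining the two frames on the interval bounds and counting the rows per interval; every count should be 10:

// hypothetical sanity check: each (id, start_index, end_index) interval should cover exactly 10 raw rows
dfIntervals
  .join(dfRaw, dfIntervals("id") === dfRaw("id") &&
    dfRaw("index").between(dfIntervals("start_index"), dfIntervals("end_index")))
  .groupBy(dfIntervals("id"), dfIntervals("start_index"), dfIntervals("end_index"))
  .count()
  .show()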
PS: I could not find a correlation function in Scala. Let's assume such a function exists, and that we will wrap it in a UDF if needed.
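For reference, a minimal Pearson correlation over two equal-length sequences might look like the sketch below. This is my addition, not part of the question or answer, and the helper name pearson is made up:

// minimal Pearson correlation sketch for two equal-length sequences (assumed helper, not from the post)
def pearson(xs: Seq[Double], ys: Seq[Double]): Double = {
  require(xs.length == ys.length && xs.nonEmpty, "sequences must be non-empty and of equal length")
  val n = xs.length
  val meanX = xs.sum / n
  val meanY = ys.sum / n
  val cov  = xs.zip(ys).map { case (x, y) => (x - meanX) * (y - meanY) }.sum
  val stdX = math.sqrt(xs.map(x => (x - meanX) * (x - meanX)).sum)
  val stdY = math.sqrt(ys.map(y => (y - meanY) * (y - meanY)).sum)
  if (stdX == 0.0 || stdY == 0.0) 0.0 else cov / (stdX * stdY)
}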
Answer (score: 1)
OK. Let's define a UDF function. For testing purposes, let's assume it always returns 1.
import scala.collection.mutable
import org.apache.spark.sql.{functions, Row}
import org.apache.spark.sql.functions.udf

// placeholder correlation UDF: always returns 1 for now
val correlation = functions.udf( (values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) => {
  1f
})
// orders the collected (index, value) structs by index and keeps only the values
val orderUdf = udf((values: mutable.WrappedArray[Row]) => {
  values.sortBy(r => r.getAs[Int](0)).map(r => r.getAs[Double](1))
})
Then join your two dataframes with the rules defined above and collect the value column into a single column called values, applying our orderUdf in the process:
val df = dfIntervals.join(dfRaw,dfIntervals("id") === dfRaw("id") && dfIntervals("start_index") <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index") )
.groupBy(dfIntervals("id"), dfIntervals("start_index"), dfIntervals("end_index"))
.agg(orderUdf(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("values"))
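As a side note (my addition, not part of the original answer), Spark's built-in sort_array orders an array of structs by its first field, so the Row-sorting UDF could likely be replaced with built-in functions:

// assumed alternative: sort_array orders the structs by their first field (index),
// so the ordering UDF can be avoided; the names dfAlt and "pairs" are illustrative
val dfAlt = dfIntervals.join(dfRaw, dfIntervals("id") === dfRaw("id") &&
    dfRaw("index").between(dfIntervals("start_index"), dfIntervals("end_index")))
  .groupBy(dfIntervals("id"), dfIntervals("start_index"), dfIntervals("end_index"))
  .agg(sort_array(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("pairs"))
  .withColumn("values", col("pairs.value"))  // field extraction over an array of structs yields array<double>
  .drop("pairs")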
Finally, apply our UDF and show the result.
df.withColumn("corr_w_template1",correlation(df("values"), lit(template1)))
.withColumn("corr_w_template2",correlation(df("values"), lit(template2)))
.withColumn("corr_w_template3",correlation(df("values"), lit(template3)))
.show(10)
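If a real correlation is wanted, the placeholder body could be swapped for something like the hypothetical pearson helper sketched earlier (again my assumption, not part of the answer):

// assumption: wire a real correlation (e.g. the pearson sketch above) into the UDF instead of the constant 1f
val correlation = functions.udf(
  (values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) =>
    pearson(values, template)
)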
Here is the full sample code:
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession, functions}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import scala.collection.JavaConverters._
import scala.collection.mutable
val conf = new SparkConf().setAppName("learning").setMaster("local[2]")
val session = SparkSession.builder().config(conf).getOrCreate()
val mySeed = 1000
/* Defining templates for correlation analysis*/
val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)
val schema1 = DataTypes.createStructType(Array(
DataTypes.createStructField("id",DataTypes.StringType,false),
DataTypes.createStructField("index",DataTypes.createArrayType(DataTypes.IntegerType),false)
))
val schema2 = DataTypes.createStructType(Array(
DataTypes.createStructField("id",DataTypes.StringType,false),
DataTypes.createStructField("start_index",DataTypes.IntegerType,false),
DataTypes.createStructField("end_index",DataTypes.IntegerType,false)
))
/* Defining raw data*/
var dfRaw = session.createDataFrame(Seq(
("A", (1023 to 1603 by 23).toArray),
("B", (341 to 2300 by 25).toArray),
("C", (2756 to 3954 by 24).toArray)
).map(r => Row(r._1 , r._2)).asJava, schema1)
dfRaw = dfRaw.select(dfRaw("id"), explode(dfRaw("index")) as "index")
.withColumn("value", rand(seed=mySeed))
/* Defining intervals*/
var dfIntervals = session.createDataFrame(Seq(
("A", 1069, 1276),
("B", 2066, 2291),
("B", 1616, 1841),
("C", 3716, 3932)
).map(r => Row(r._1 , r._2,r._3)).asJava, schema2)
//Define udf
val correlation = functions.udf( (values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) => {
1f
})
val orderUdf = udf((values: mutable.WrappedArray[Row]) => {
values.sortBy(r => r.getAs[Int](0)).map(r => r.getAs[Double](1))
})
val df = dfIntervals.join(dfRaw,dfIntervals("id") === dfRaw("id") && dfIntervals("start_index") <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index") )
.groupBy(dfIntervals("id"), dfIntervals("start_index"), dfIntervals("end_index"))
.agg(orderUdf(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("values"))
df.withColumn("corr_w_template1",correlation(df("values"), lit(template1)))
.withColumn("corr_w_template2",correlation(df("values"), lit(template2)))
.withColumn("corr_w_template3",correlation(df("values"), lit(template3)))
.show(10,false)