我正在尝试将以下pyspark代码转换为scala。如您所知,scala中的数据帧是不可变的,这迫使我转换以下代码:
pyspark代码:
time_frame = ["3m","6m","9m","12m","18m","27m","60m","60m_ab"]
variable_name = ["var1", "var2", "var3"....., "var30"]
train_df = sqlContext.sql("select * from someTable")
for var in variable_name:
for tf in range(1,len(time_frame)):
train_df=train_df.withColumn(str(time_frame[tf]+'_'+var), fn.col(str(time_frame[tf]+'_'+var))+fn.col(str(time_frame[tf-1]+'_'+var)))
因此,如您在表上方所见,该表具有不同的列,这些列用于重新创建更多列。但是,Spark / Scala中数据帧的不可变性质遭到了反对,您可以帮我解决一些问题吗?
答案 0 :(得分:0)
这是一种方法,该方法首先使用for-comprehension
生成由列名对组成的元组列表,然后使用foldLeft
遍历该列表以通过{{1 }}:
trainDF
要测试以上代码,只需提供示例输入并创建表withColumn
:
import org.apache.spark.sql.functions._
val timeframes: Seq[String] = ???
val variableNames: Seq[String] = ???
val newCols = for {
vn <- variableNames
tf <- 1 until timeframes.size
} yield (timeframes(tf) + "_" + vn, timeframes(tf - 1) + "_" + vn)
val trainDF = spark.sql("""select * from some_table""")
val resultDF = newCols.foldLeft(trainDF)( (accDF, cs) =>
accDF.withColumn(cs._1, col(cs._1) + col(cs._2))
)
some_table
应该如下所示:
val timeframes = Seq("3m", "6m", "9m")
val variableNames = Seq("var1", "var2")
val df = Seq(
(1, 10, 11, 12, 13, 14, 15),
(2, 20, 21, 22, 23, 24, 25),
(3, 30, 31, 32, 33, 34, 35)
).toDF("id", "3m_var1", "6m_var1", "9m_var1", "3m_var2", "6m_var2", "9m_var2")
df.createOrReplaceTempView("some_table")