How to return a case class when using Spark higher-order functions?

Date: 2019-11-24 13:52:17

Tags: scala apache-spark

I am trying to use Spark's transform function to convert the items of an array from type ClassA to ClassB, as shown below:

case class ClassA(a: String, b: String, c: String)
case class ClassB(a: String, b: String)

val a1 = ClassA("a1", "b1", "c1")
val a2 = ClassA("a2", "b2", "c2")

val df = Seq(
  (Seq(a1, a2))
).toDF("ClassA")

df.withColumn("ClassB", expr("transform(ClassA, c -> ClassB(c.a, c.b))")).show(false)

However, the above code fails with the following message:

org.apache.spark.sql.AnalysisException: Undefined function: 'ClassB'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.

The only way I could make this work was via struct, as shown next:

df.withColumn("ClassB", expr("transform(ClassA, c -> struct(c.a as string, c.b as string))")).show(false)

// +----------------------------+--------------------+
// |ClassA                      |ClassB              |
// +----------------------------+--------------------+
// |[[a1, b1, c1], [a2, b2, c2]]|[[a1, b1], [a2, b2]]|
// +----------------------------+--------------------+

So the question is: when using transform, is there any way to return a case class instead of a struct?

1 Answer:

Answer 0: (score: 3)

The transform expression is relational and knows nothing about the case classes ClassA and ClassB. AFAIK, the only way would be to register a UDF so you can use your ClassB structure (or inject functions), but then you also have to deal with "encoded" ClassA values rather than ClassA itself (Spark SQL is all about encoding :) ).
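A minimal sketch of that UDF route: it assumes the array&lt;struct&gt; column arrives in the UDF as Seq[org.apache.spark.sql.Row] (the usual encoded form for struct elements), and the name toClassB is illustrative, not from the original answer:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, udf}

// Inside the UDF the struct elements arrive as Rows, not as ClassA instances,
// so the fields are read back by name before building ClassB.
val toClassB = udf { (items: Seq[Row]) =>
  items.map(r => ClassB(r.getAs[String]("a"), r.getAs[String]("b")))
}

df.withColumn("ClassB", toClassB(col("ClassA"))).show(false)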

Side note: naming the column "ClassA" can be confusing, since transform reads a column, not a type.
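As a hypothetical aside (not from the original answer): if staying inside SQL expressions is not a requirement, the typed Dataset API works with case classes directly, because encoders for ClassA and ClassB are derived automatically:

import spark.implicits._

// Build the same data as a typed Dataset and map on the JVM side,
// where ClassA and ClassB are ordinary case classes.
val ds = Seq(Seq(a1, a2)).toDS()               // Dataset[Seq[ClassA]]
val bs = ds.map(_.map(a => ClassB(a.a, a.b)))  // Dataset[Seq[ClassB]]
bs.show(false)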