Subset one array column with another (boolean) array column

Time: 2019-04-22 16:06:25

Tags: apache-spark pyspark apache-spark-sql pyspark-sql

I have a DataFrame like this (in PySpark 2.3.1):

from pyspark.sql import Row

my_data = spark.createDataFrame([
  Row(a=[9, 3, 4], b=['a', 'b', 'c'], mask=[True, False, False]),
  Row(a=[7, 2, 6, 4], b=['w', 'x', 'y', 'z'], mask=[True, False, True, False])
])
my_data.show(truncate=False)
#+------------+------------+--------------------------+
#|a           |b           |mask                      |
#+------------+------------+--------------------------+
#|[9, 3, 4]   |[a, b, c]   |[true, false, false]      |
#|[7, 2, 6, 4]|[w, x, y, z]|[true, false, true, false]|
#+------------+------------+--------------------------+

Now I want to use the mask column to subset the a and b columns:

my_desired_output = spark.createDataFrame([
  Row(a=[9], b=['a']),
  Row(a=[7, 6], b=['w', 'y'])
])
my_desired_output.show(truncate=False)
#+------+------+
#|a     |b     |
#+------+------+
#|[9]   |[a]   |
#|[7, 6]|[w, y]|
#+------+------+

What is the "idiomatic" way to achieve this? My current solution maps over the underlying RDD and does the subsetting with NumPy, which seems inelegant:

import numpy as np

def subset_with_mask(row):
    mask = np.asarray(row.mask)
    a_masked = np.asarray(row.a)[mask].tolist()
    b_masked = np.asarray(row.b)[mask].tolist()
    return Row(a=a_masked, b=b_masked)

my_desired_output = spark.createDataFrame(my_data.rdd.map(subset_with_mask))

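Is this the best approach, or can I do better (less verbose and/or more efficient) with Spark SQL tools?

3 Answers:

Answer 0 (score: 2)

One option is to use a UDF, which you can optionally specialize by the data type in the array: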

import java.util.Arrays;
import java.util.Comparator;

public class PartNumberQuantityDetailer {
// initialize a two dimensional array
static Integer[][] itemIdAndQty = new Integer[5][2];

public static void main(String[] args) {
    // initialize array values
    itemIdAndQty[0][0] = 1234;
    itemIdAndQty[0][1] = 46;
    itemIdAndQty[1][0] = 5443;
    itemIdAndQty[1][1] = 564;
    itemIdAndQty[2][0] = 362;
    itemIdAndQty[2][1] = 24;
    itemIdAndQty[3][0] = 6742;
    itemIdAndQty[3][1] = 825;
    itemIdAndQty[4][0] = 347;
    itemIdAndQty[4][1] = 549;
    System.out.println("Before sorting");
    // show the contents of array
    displayArray();
    // sort the array on item id(first column)
    Arrays.sort(itemIdAndQty, new Comparator<Integer[]>() {
        // arguments to this method represent the arrays to be sorted
        @Override
        public int compare(Integer[] o1, Integer[] o2) {
            // get the item ids which are at index 0 of the array
            Integer itemIdOne = o1[0];
            Integer itemIdTwo = o2[0];
            // sort on item id
            return itemIdOne.compareTo(itemIdTwo);
        }
    });
    // display array after sort
    System.out.println("After sorting on item id in ascending order");
    displayArray();
    // sort array on quantity(second column)
    Arrays.sort(itemIdAndQty, new Comparator<Integer[]>() {
        @Override
        public int compare(Integer[] o1, Integer[] o2) {
            Integer quantityOne = o1[1];
            Integer quantityTwo = o2[1];
            // sort on quantity in ascending order
            return quantityOne.compareTo(quantityTwo);
        }
    });
    // display array after sort
    System.out.println("After sorting on quantity in ascending order");
    displayArray();

}

private static void displayArray() {
    System.out.println("-------------------------------------");
    System.out.println("Item id\t\tQuantity");
    for (int i = 0; i < itemIdAndQty.length; i++) {
        Integer[] itemRecord = itemIdAndQty[i];
        System.out.println(itemRecord[0] + "\t\t" + itemRecord[1]);
    }
    System.out.println("-------------------------------------");
}
}
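
As a minimal sketch of the UDF approach this answer describes, assuming the `my_data` DataFrame from the question: `apply_mask` and `subset_columns` are hypothetical names, not part of any library. One UDF is created per element type (`LongType` for `a`, since Spark infers Python ints as longs, and `StringType` for `b`). Null arrays are not handled here.

```python
from itertools import compress


def apply_mask(values, mask):
    """Keep the elements of `values` whose matching `mask` entry is True."""
    return list(compress(values, mask))


def subset_columns(df):
    """Apply the mask to columns `a` and `b` of a Spark DataFrame via UDFs."""
    # Imported lazily so apply_mask stays usable without a Spark installation.
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import ArrayType, LongType, StringType

    # Same Python function, specialized by the declared return type.
    mask_longs = udf(apply_mask, ArrayType(LongType()))
    mask_strings = udf(apply_mask, ArrayType(StringType()))

    return df.select(
        mask_longs(col("a"), col("mask")).alias("a"),
        mask_strings(col("b"), col("mask")).alias("b"),
    )
```

Calling `subset_columns(my_data).show()` should then give the two-column result the question asks for.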

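Answer 1 (score: 1)

The UDF mentioned in the previous answer is probably the way to go before the array functions added in Spark 2.4. For completeness, here is a "pure SQL" implementation for versions before 2.4.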

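from pyspark.sql.functions import collect_list, monotonically_increasing_id, posexplode

# Tag each row with an id so the exploded elements can be regrouped later.
df = my_data.withColumn("row", monotonically_increasing_id())

# Explode each array column together with the position of every element.
df1 = df.select("row", posexplode("a").alias("pos", "a"))
df2 = df.select("row", posexplode("b").alias("pos", "b"))
df3 = df.select("row", posexplode("mask").alias("pos", "mask"))

# Align elements on (row, pos), keep those whose mask is true, rebuild the arrays.
# Note: collect_list does not guarantee element order after a shuffle.
df1\
    .join(df2, ["row", "pos"])\
    .join(df3, ["row", "pos"])\
    .filter("mask")\
    .groupBy("row")\
    .agg(collect_list("a").alias("a"), collect_list("b").alias("b"))\
    .select("a", "b")\
    .show()

Output: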

+------+------+
|     a|     b|
+------+------+
|[7, 6]|[w, y]|
|   [9]|   [a]|
+------+------+
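
For Spark 2.4 and later, the array functions mentioned above allow a join-free variant. The following is only a sketch, assuming the built-in SQL functions `sequence`, `size`, `filter` and `transform`; `masked_expr` and `subset_with_sql` are hypothetical helper names, and the behavior for empty or NULL arrays has not been checked here.

```python
def masked_expr(column, mask="mask"):
    """Build a Spark SQL expression that keeps column[i] where mask[i] is true."""
    # sequence(0, n - 1) lists the positions, filter keeps the positions whose
    # mask entry is true, and transform looks up the element at each of them.
    return (
        "transform("
        "filter(sequence(0, size({m}) - 1), i -> {m}[i]), "
        "i -> {c}[i])"
    ).format(m=mask, c=column)


def subset_with_sql(df):
    """Apply the mask to columns `a` and `b` (assumes Spark 2.4 or later)."""
    # Imported lazily so masked_expr can be inspected without Spark installed.
    from pyspark.sql.functions import expr

    return df.select(
        expr(masked_expr("a")).alias("a"),
        expr(masked_expr("b")).alias("b"),
    )
```

Because each row is processed in place, this avoids the shuffle and the ordering caveat of the explode/join/collect_list version.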

Answer 2 (score: 0)

Here is another approach that uses two UDFs to zip and unzip the lists:

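from pyspark.sql.types import ArrayType, StructType, StructField, StringType
from pyspark.sql.functions import udf, col, lit

# Define the plain Python helpers first so they exist when the UDFs are created.
def my_zip(a, b, mask):
    # Keep only the (a, b) pairs whose mask entry is true.
    return [(x[0], x[1]) for x in zip(a, b, mask) if x[2]]

def my_unzip(zipped, indx):
    # Pull out field number indx from every pair, as strings.
    return [str(x[indx]) for x in zipped]

# Both fields are declared as strings, so the values of a come back as strings.
zip_schema = ArrayType(StructType((StructField("a", StringType()), StructField("b", StringType()))))
unzip_schema = ArrayType(StringType())

zip_udf = udf(my_zip, zip_schema)
unzip_udf = udf(my_unzip, unzip_schema)

# Parentheses let the method chain span several lines.
df = (my_data.withColumn("zipped", zip_udf(col("a"), col("b"), col("mask")))
      .withColumn("a", unzip_udf(col("zipped"), lit(0)))
      .withColumn("b", unzip_udf(col("zipped"), lit(1)))
      .drop("zipped", "mask"))

my_zip filters the data by the mask and builds (a, b) tuples, each of which is one item of the returned list.

my_unzip extracts the element at a given index indx from the data produced by my_zip.

Output: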

+------+------+
|     a|     b|
+------+------+
|   [9]|   [a]|
|[7, 6]|[w, y]|
+------+------+
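
The two helpers can also be tried outside Spark to see what each step produces. This is a plain-Python sketch that reuses the names from the answer on the second row of the question's data; no Spark session is involved.

```python
def my_zip(a, b, mask):
    """Pair up a and b, keeping only the positions where mask is true."""
    return [(x, y) for x, y, keep in zip(a, b, mask) if keep]


def my_unzip(zipped, indx):
    """Extract field number indx from every pair, converted to str."""
    return [str(item[indx]) for item in zipped]


zipped = my_zip([7, 2, 6, 4], ["w", "x", "y", "z"], [True, False, True, False])
print(zipped)               # [(7, 'w'), (6, 'y')]
print(my_unzip(zipped, 0))  # ['7', '6']
print(my_unzip(zipped, 1))  # ['w', 'y']
```

Note that the `str` conversion turns the integers of `a` into strings, which matches the `ArrayType(StringType())` return type declared for the unzip UDF.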