使用pyspark进行重量取样

时间:2018-02-01 08:17:11

标签: python apache-spark pyspark sampling

我使用PySpark在spark上有一个不平衡的数据帧。 我想重新取样以使其平衡。 我只在PySpark中找到了示例函数

sample(withReplacement, fraction, seed=None)

但我想用unitvolume的权重对数据帧进行采样 在Python中,我可以像

那样做
df.sample(n,Flase,weights=log(unitvolume))

有什么方法可以使用PySpark做同样的事情吗?

2 个答案:

答案 0 :(得分:1)

Spark提供了分层抽样的工具,但这仅适用于分类数据。你可以试着把它搞砸一下:

from pyspark.ml.feature import Bucketizer
from pyspark.sql.functions import col, log

df_log = df.withColumn("log_unitvolume", log(col("unitvolume"))
splits = ... # A list of splits

bucketizer = Bucketizer(splits=splits, inputCol="log_unitvolume", outputCol="bucketed_log_unitvolume")

df_log_bucketed = bucketizer.transform(df_log)

计算统计数据:

counts = df.groupBy("bucketed_log_unitvolume")
fractions  = ...  # Define fractions from each bucket:

并使用这些进行抽样:

df_log_bucketed.sampleBy("bucketed_log_unitvolume", fractions)

您还可以尝试将log_unitvolume重新调整为[0,1]范围,然后:

from pyspark.sql.functions import rand 

df_log_rescaled.where(col("log_unitvolume_rescaled") < rand())

答案 1 :(得分:0)

一种方法是使用udf制作采样列。此列将随机数乘以您所需的权重。然后我们按采样列排序,并取得前N个。

考虑以下说明性示例:

创建虚拟数据

import numpy as np
import string
import pyspark.sql.functions as f

index = range(100)
weights = [i%26 for i in index]
labels = [string.ascii_uppercase[w] for w in weights]

df = sqlCtx.createDataFrame(
    zip(index, labels, weights),
    ('index', 'label', 'weight')
)

df.show(n=5)
#+-----+-----+------+
#|index|label|weight|
#+-----+-----+------+
#|    0|    A|     0|
#|    1|    B|     1|
#|    2|    C|     2|
#|    3|    D|     3|
#|    4|    E|     4|
#+-----+-----+------+
#only showing top 5 rows

添加采样列

在此示例中,我们希望使用列weight作为权重对DataFrame进行采样。我们使用udf定义numpy.random.random()以生成统一的随机数并乘以权重。然后我们在此列上使用sort()并使用limit()来获取所需的样本数。

N = 10  # the number of samples

def get_sample_value(x):
    return np.random.random() * x

get_sample_value_udf = f.udf(get_sample_value, FloatType())

df_sample = df.withColumn('sampleVal', get_sample_value_udf(f.col('weight')))\
    .sort('sampleVal', ascending=False)\
    .select('index', 'label', 'weight')\
    .limit(N)

<强>结果

正如预期的那样,DataFrame df_sample有10行,它的内容往往在字母表末尾附近有字母(权重较高)。

df_sample.count()
#10

df_sample.show()
#+-----+-----+------+
#|index|label|weight|
#+-----+-----+------+
#|   23|    X|    23|
#|   73|    V|    21|
#|   46|    U|    20|
#|   25|    Z|    25|
#|   19|    T|    19|
#|   96|    S|    18|
#|   75|    X|    23|
#|   48|    W|    22|
#|   51|    Z|    25|
#|   69|    R|    17|
#+-----+-----+------+