如何使用sample()函数执行上采样(py-spark)

时间:2018-11-13 02:57:48

标签: machine-learning pyspark random-forest sampling

我正在研究二进制分类机器学习问题,并且由于目标类变量不平衡,我试图平衡训练集。我正在使用Py-Spark构建模型。

下面是用于平衡数据的代码

(5, 3, Decimal('4'))

上面的代码执行欠采样,但是我认为这可能会导致信息丢失。但是,我不确定如何执行升采样。我也尝试使用如下示例函数:

train_initial, test = new_data.randomSplit([0.7, 0.3], seed = 2018)
train_initial.groupby('label').count().toPandas()
   label   count                                                                
0    0.0  712980
1    1.0    2926
train_new = train_initial.sampleBy('label', fractions={0: 2926./712980, 1: 1.0}).cache()

尽管它在我的数据集中增加了1个计数,但也增加了0个计数并给出以下结果。

train_up = train_initial.sample(True, 10.0, seed = 2018)

有人可以帮我实现py-spark的上采样吗?

非常感谢!!

3 个答案:

答案 0 :(得分:2)

问题是您对整个数据帧进行了过度采样。您应该过滤两个类中的数据

df_class_0 = df_train[df_train['label'] == 0]
df_class_1 = df_train[df_train['label'] == 1]
df_class_1_over = df_class_1.sample(count_class_0, replace=True)
df_test_over = pd.concat([df_class_0, df_class_1_over], axis=0)

该示例来自:https://www.kaggle.com/rafjaa/resampling-strategies-for-imbalanced-datasets

请注意,有更好的方法来执行过采样(例如SMOTE)

答案 1 :(得分:1)

在这里营救我可能已经很晚了。但这是我的建议:

第1步。仅针对标签= 1采样

train_1= train_initial.where(col('label')==1).sample(True, 10.0, seed = 2018)

第2步。将此数据与label = 0数据合并

train_0=train_initial.where(col('label')==0)
train_final = train_0.union(train_1)

PS:请使用导入

from pyspark.sql.functions import col

答案 2 :(得分:1)

对于任何试图在pyspark中的不平衡数据集上进行随机过采样的人。以下代码将帮助您入门(在此代码段中,0是mayority类,而1是要过采样的类):

df_a = df.filter(df['label'] == 0)
df_b = df.filter(df['label'] == 1)

a_count = df_a.count()
b_count = df_b.count() 
ratio = a_count / b_count

df_b_overampled = df_b.sample(withReplacement=True, fraction=ratio, seed=1)
df = df_a.unionAll(df_b_oversampled)