How to handle categorical variables in Spark 1.6?

Asked: 2016-05-12 07:13:03

Tags: scala apache-spark apache-spark-sql

I am building a random forest classification model on my dataset. I handled some of the categorical variables with the oneHotEncode method below, which I wrote using the StringIndexer and OneHotEncoder provided by Spark 1.6. As a result I ended up with a number of sparse vectors.

My code:

    import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
    import org.apache.spark.sql.DataFrame

    // Index the string column a into numeric column b, then one-hot encode
    // b into a sparse-vector column c.
    def oneHotEncode(a: String, b: String, c: String, selectedData: DataFrame): DataFrame = {
      val indexer = new StringIndexer().setInputCol(a).setOutputCol(b).fit(selectedData)
      val indexed = indexer.transform(selectedData)
      val encoder = new OneHotEncoder().setInputCol(b).setOutputCol(c)
      encoder.transform(indexed)
    }

    var data1 = oneHotEncode("ispromoteroom", "ispromoteroomIndex", "ispromoteroomVec", selectedData)

My question: how do I combine the sparse vectors with the other raw continuous variables in the dataset and convert the result into the LabeledPoint data type?
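
To clarify what I am after, here is a minimal sketch of what I imagine, using VectorAssembler to merge the columns and then mapping each row to a LabeledPoint. The continuous column "price" and the label column "label" are hypothetical stand-ins for my real columns:

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.mllib.regression.LabeledPoint

    // Merge the one-hot vector with the raw continuous columns into a single
    // feature vector ("price" stands in for one of my continuous columns).
    val assembler = new VectorAssembler()
      .setInputCols(Array("ispromoteroomVec", "price"))
      .setOutputCol("features")
    val assembled = assembler.transform(data1)

    // Turn each row into a LabeledPoint, assuming a numeric "label" column.
    val labeled = assembled.select("label", "features").map { row =>
      LabeledPoint(row.getDouble(0), row.getAs[Vector](1))
    }

Is this the right way to do it, or is there a more direct approach?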

1 Answer:

Answer 0 (score: 0):

I followed this tutorial, and it was very helpful.

    from numpy import array
    from pyspark.mllib.regression import LabeledPoint

    # protocols, services and flags are lists of the distinct categorical
    # values, built earlier in the tutorial.
    def create_labeled_point(line_split):
        # leave out the label column (index 41)
        clean_line_split = line_split[0:41]

        # convert protocol to a numeric categorical variable
        try:
            clean_line_split[1] = protocols.index(clean_line_split[1])
        except ValueError:
            clean_line_split[1] = len(protocols)

        # convert service to a numeric categorical variable
        try:
            clean_line_split[2] = services.index(clean_line_split[2])
        except ValueError:
            clean_line_split[2] = len(services)

        # convert flag to a numeric categorical variable
        try:
            clean_line_split[3] = flags.index(clean_line_split[3])
        except ValueError:
            clean_line_split[3] = len(flags)

        # convert the label to a binary label: 0.0 for normal, 1.0 for attack
        attack = 1.0
        if line_split[41] == 'normal.':
            attack = 0.0

        return LabeledPoint(attack, array([float(x) for x in clean_line_split]))

    training_data = csv_data.map(create_labeled_point)
    test_data = test_csv_data.map(create_labeled_point)