How to use string variables in VectorAssembler in PySpark

Date: 2017-09-20 21:25:39

Tags: pyspark random-forest

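I want to run the Random Forests algorithm in PySpark. The PySpark documentation says that VectorAssembler accepts only numeric or boolean data types. So if my data contains StringType variables, such as city names, should I one-hot encode them before going on to random forest classification/regression?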

Here is the code I have been trying; the input file is here.

from pyspark.ml.feature import VectorAssembler

train = sqlContext.read.format('com.databricks.spark.csv').options(header='true').load('filename')
# Despite its name, drop_list holds the columns that are kept by the select below
drop_list = ["Country", "Carrier", "TrafficType", "Device", "Browser", "OS", "Fraud", "ConversionPayOut"]
# Only this variable is actually numeric; the rest of them are strings
train = train.withColumn("ConversionPayOut", train["ConversionPayOut"].cast("double"))
junk = train.select([column for column in train.columns if column in drop_list])
# Assumed definition (not shown originally): assemble every selected column
assembler = VectorAssembler(inputCols=junk.columns, outputCol="features")
transformed = assembler.transform(junk)

I keep getting the error IllegalArgumentException: u'Data type StringType is not supported.'
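The exception is raised by VectorAssembler's input-type check: it accepts numeric, boolean and vector columns, so every string column fails. A minimal pure-Python sketch of that check, assuming the schema is given as a hypothetical {column: type name} dict with simplified type names (this is not the Spark API), shows which of the question's columns would be rejected:

```python
# Hypothetical stand-in for VectorAssembler's input-type check.
# Spark accepts numeric, boolean and vector columns; type names are simplified here.
SUPPORTED = {"byte", "short", "integer", "long", "float", "double",
             "decimal", "boolean", "vector"}


def unsupported_columns(schema, input_cols):
    """Return the input columns whose type VectorAssembler would reject."""
    return [c for c in input_cols if schema[c] not in SUPPORTED]


# Column types as described in the question: only ConversionPayOut was cast to double.
schema = {
    "Country": "string", "Carrier": "string", "TrafficType": "string",
    "Device": "string", "Browser": "string", "OS": "string",
    "Fraud": "string", "ConversionPayOut": "double",
}

bad = unsupported_columns(schema, list(schema))
print(bad)
```

Every column except ConversionPayOut is flagged, which is why the assembler fails until the string columns are converted to numeric form.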

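P.S.: Apologies for the basic question; I come from an R background. In R, random forests do not require converting categorical variables to numeric ones.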

2 Answers:

Answer 0 (score: 1)

Yes, you should use StringIndexer, possibly together with OneHotEncoder. You can find more information about both in the linked documentation.
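To make the two steps concrete, here is a minimal pure-Python sketch of what each transformer produces, not the Spark implementation. It assumes StringIndexer's default ordering (most frequent label gets index 0, with ties broken alphabetically as in Spark 3.x) and OneHotEncoder's default dropLast=True. Spark itself stores indices as doubles and emits sparse vectors; the sketch uses ints and dense lists, and the city names are made up for illustration.

```python
from collections import Counter


def fit_string_indexer(values):
    """Map each distinct string to an index, most frequent first.

    Mimics StringIndexer's default frequency-descending order;
    ties are broken alphabetically.
    """
    counts = Counter(values)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    return {label: i for i, label in enumerate(ordered)}


def one_hot(index, num_categories, drop_last=True):
    """Dense one-hot vector; with drop_last=True the last category
    is represented by the all-zeros vector."""
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if index < size:
        vec[index] = 1.0
    return vec


cities = ["Paris", "Lyon", "Paris", "Nice", "Lyon", "Paris"]
mapping = fit_string_indexer(cities)
print(mapping)  # {'Paris': 0, 'Lyon': 1, 'Nice': 2}
encoded = [one_hot(mapping[c], len(mapping)) for c in cities]
print(encoded[0], encoded[3])  # [1.0, 0.0] [0.0, 0.0]
```

Whether the one-hot step is needed depends on the downstream model; the sketch only shows the shape of each transformer's output.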

Answer 1 (score: 0)

Here is an example in Scala; the same transformers are available in PySpark under pyspark.ml.feature.
Schema of the training data:
 |-- age: integer (nullable = true)
 |-- workclass: string (nullable = true)
 |-- fnlwgt: double (nullable = true)
 |-- education: string (nullable = true)
 |-- education-num: double (nullable = true)
 |-- marital-status: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- relationship: string (nullable = true)
 |-- race: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- capital-gain: double (nullable = true)
 |-- capital-loss: double (nullable = true)
 |-- hours-per-week: double (nullable = true)
 |-- native-country: string (nullable = true)
 |-- income: string (nullable = true)

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}

    // Deal with categorical columns:
    // map each string column to a numeric index column with StringIndexer
    val workclassIndexer = new StringIndexer().setInputCol("workclass").setOutputCol("workclassIndex")
    val educationIndexer = new StringIndexer().setInputCol("education").setOutputCol("educationIndex")
    val maritalStatusIndexer = new StringIndexer().setInputCol("marital-status").setOutputCol("maritalStatusIndex")
    val occupationIndexer = new StringIndexer().setInputCol("occupation").setOutputCol("occupationIndex")
    val relationshipIndexer = new StringIndexer().setInputCol("relationship").setOutputCol("relationshipIndex")
    val raceIndexer = new StringIndexer().setInputCol("race").setOutputCol("raceIndex")
    val sexIndexer = new StringIndexer().setInputCol("sex").setOutputCol("sexIndex")
    val nativeCountryIndexer = new StringIndexer().setInputCol("native-country").setOutputCol("nativeCountryIndex")
    // The target is indexed straight into "label", the classifier's default label column
    val incomeIndexer = new StringIndexer().setInputCol("income").setOutputCol("label")

    // One-hot encode the indexed feature columns
    val workclassEncoder = new OneHotEncoder().setInputCol("workclassIndex").setOutputCol("workclassVec")
    val educationEncoder = new OneHotEncoder().setInputCol("educationIndex").setOutputCol("educationVec")
    val maritalStatusEncoder = new OneHotEncoder().setInputCol("maritalStatusIndex").setOutputCol("maritalVec")
    val occupationEncoder = new OneHotEncoder().setInputCol("occupationIndex").setOutputCol("occupationVec")
    val relationshipEncoder = new OneHotEncoder().setInputCol("relationshipIndex").setOutputCol("relationshipVec")
    val raceEncoder = new OneHotEncoder().setInputCol("raceIndex").setOutputCol("raceVec")
    val sexEncoder = new OneHotEncoder().setInputCol("sexIndex").setOutputCol("sexVec")
    val nativeCountryEncoder = new OneHotEncoder().setInputCol("nativeCountryIndex").setOutputCol("nativeCountryVec")

    // Assemble everything into the ("label", "features") format
    val assembler = new VectorAssembler()
      .setInputCols(Array("workclassVec", "fnlwgt", "educationVec", "education-num", "maritalVec", "occupationVec", "relationshipVec", "raceVec", "sexVec", "capital-gain", "capital-loss", "hours-per-week", "nativeCountryVec"))
      .setOutputCol("features")

    // Set up the pipeline; another classifier (e.g. RandomForestClassifier) can replace lr
    val lr = new LogisticRegression()

    val pipeline = new Pipeline().setStages(Array(
      workclassIndexer, educationIndexer, maritalStatusIndexer, occupationIndexer,
      relationshipIndexer, raceIndexer, sexIndexer, nativeCountryIndexer, incomeIndexer,
      workclassEncoder, educationEncoder, maritalStatusEncoder, occupationEncoder,
      relationshipEncoder, raceEncoder, sexEncoder, nativeCountryEncoder,
      assembler, lr))

    // Fit the pipeline to the training DataFrame (the one with the schema above)
    val model = pipeline.fit(training)
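Writing one indexer and one encoder per column by hand gets tedious as the number of string columns grows. A minimal sketch, in pure Python and assuming the schema is written as a hypothetical {column: type} dict with a hypothetical "<col>Index"/"<col>Vec" naming convention, derives every stage's input and output column names from the schema instead:

```python
# Schema of the example DataFrame from the answer above, written as a
# hypothetical {column: type} dict rather than a real Spark schema.
SCHEMA = {
    "age": "integer", "workclass": "string", "fnlwgt": "double",
    "education": "string", "education-num": "double",
    "marital-status": "string", "occupation": "string",
    "relationship": "string", "race": "string", "sex": "string",
    "capital-gain": "double", "capital-loss": "double",
    "hours-per-week": "double", "native-country": "string",
    "income": "string",
}


def plan_stages(schema, label_col):
    """Derive the (input, output) column names each pipeline stage would use."""
    features = [c for c in schema if c != label_col]
    categorical = [c for c in features if schema[c] == "string"]
    numeric = [c for c in features if schema[c] != "string"]
    return {
        "indexers": [(c, c + "Index") for c in categorical],
        "label_indexer": (label_col, "label"),
        "encoders": [(c + "Index", c + "Vec") for c in categorical],
        # Encoded categorical vectors first, then the numeric columns as-is
        "assembler_inputs": [c + "Vec" for c in categorical] + numeric,
    }


plan = plan_stages(SCHEMA, "income")
for name, cols in plan.items():
    print(name, cols)
```

In PySpark, each pair should map onto StringIndexer(inputCol=..., outputCol=...) and OneHotEncoder(inputCol=..., outputCol=...), and assembler_inputs onto VectorAssembler(inputCols=..., outputCol="features"). Unlike the hand-written Scala list, this plan also includes the integer age column.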