将数据帧拟合到randomForest pyspark中

时间:2017-06-02 09:35:01

标签: python apache-spark pyspark apache-spark-ml

我的DataFrame看起来像这样:

+--------------------+------------------+
|            features|           labels |
+--------------------+------------------+
|[-0.38475, 0.568...]|          label1  |
|[0.645734, 0.699...]|          label2  |
|     .....          |          ...     |
+--------------------+------------------+

两列都是String类型(StringType()),我想把它放到spark ml randomForest中。为此,我需要将features列转换为包含浮点数的向量。有没有人知道怎么做?

1 个答案:

答案 0 :(得分:6)

如果您使用 Spark 2.x ,我相信这就是您所需要的:

from pyspark.sql.functions import udf
from pyspark.mllib.linalg import Vectors
from pyspark.ml.linalg import VectorUDT
from pyspark.ml.feature import StringIndexer

df = spark.createDataFrame([("[-0.38475, 0.568]", "label1"), ("[0.645734, 0.699]", "label2")], ("features", "label"))

def parse(s):
  try:
    return Vectors.parse(s).asML()
  except:
    return None

parse_ = udf(parse, VectorUDT())

parsed = df.withColumn("features", parse_("features"))

indexer = StringIndexer(inputCol="label", outputCol="label_indexed")

indexer.fit(parsed).transform(parsed).show()
## +----------------+------+-------------+
## |        features| label|label_indexed|
## +----------------+------+-------------+
## |[-0.38475,0.568]|label1|          0.0|
## |[0.645734,0.699]|label2|          1.0|
## +----------------+------+-------------+

使用 Spark 1.6 ,它并没有太大区别:

from pyspark.sql.functions import udf
from pyspark.ml.feature import StringIndexer
from pyspark.mllib.linalg import Vectors, VectorUDT

df = sqlContext.createDataFrame([("[-0.38475, 0.568]", "label1"), ("[0.645734, 0.699]", "label2")], ("features", "label"))

parse_ = udf(Vectors.parse, VectorUDT())

parsed = df.withColumn("features", parse_("features"))

indexer = StringIndexer(inputCol="label", outputCol="label_indexed")

indexer.fit(parsed).transform(parsed).show()
## +----------------+------+-------------+
## |        features| label|label_indexed|
## +----------------+------+-------------+
## |[-0.38475,0.568]|label1|          0.0|
## |[0.645734,0.699]|label2|          1.0|
## +----------------+------+-------------+

Vectors具有parse功能,可以帮助您实现您的目标。