Pyspark计算标记点RDD的标签的不同值

时间:2017-10-26 07:13:40

标签: apache-spark pyspark apache-spark-mllib

我在Spark中有一个labeled point的RDD。我想要计算标签的所有不同值。我尝试了一些事情

   
from pyspark.mllib.regression import LabeledPoint

train_data =  sc.parallelize([ LabeledPoint(1.0, [1.0, 0.0, 3.0]),LabeledPoint(2.0, [1.0, 0.0, 3.0]),LabeledPoint(1.0, [1.0, 0.0, 3.0]) ])

train_data.reduceByKey(lambda x : x.label).collect()

但是我得到了

  

TypeError:' LabeledPoint'对象不可迭代

我使用Spark 2.1和python 2.7。谢谢你的帮助。

1 个答案:

答案 0 :(得分:2)

您只需将LabeledPoint转换为键值RDD,然后按键计数:

   
spark.version
# u'2.1.1'

from pyspark.mllib.regression import LabeledPoint

train_data =  sc.parallelize([ LabeledPoint(1.0, [1.0, 0.0, 3.0]),LabeledPoint(2.0, [1.0, 0.0, 3.0]),LabeledPoint(1.0, [1.0, 0.0, 3.0]) ])

dd = train_data.map(lambda x: (x.label, x.features)).countByKey()
dd
# {1.0: 2, 2.0: 1}