从pandas dataframe转换为LabeledPoint RDD

时间:2017-03-22 10:39:33

标签: python pandas apache-spark pyspark apache-spark-mllib

我正在对一个非常简单的数据集进行一些测试,该数据集基本上由数值数据组成。 它可以找到here

我正在使用pandas,numpy和scikit-learn就好了但是当转移到Spark时,我无法以正确的格式设置数据以将其输入到决策树。

我这样做是行不通的:

df = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/poker/poker-hand-training-true.data')

raw_data = sc.parallelize(df)

train_dataset = raw_data.map(lambda line: line.split(","))\
                            .map(lambda line:LabeledPoint(line[10], np.array([float(x) for x in line[0:10]])))

在尝试在地图功能中进行访问IndexError: list index out of range时,我不断获得line

当我实际下载文件并更改代码时,我只是设法让它工作:

raw_data = sc.textFile('.../datasets/poker-hand-training.data')

train_dataset = raw_data.map(lambda line: line.split(","))\
                            .map(lambda line:LabeledPoint(line[10], np.array([float(x) for x in line[0:10]])))

如果我不想下载数据集,是否可以使用read_csv直接从pandas数据帧中获取数据?

1 个答案:

答案 0 :(得分:4)

我建议您先将Pandas DataFrame转换为Spark DataFrame。您可以使用sqlContext.createDataFrame方法执行此操作。

df = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/poker/poker-hand-training-true.data', names=['S1','C1','S2','C2','S3','C3','S4','C4','S5','C5','class'])
s_df = spark.createDataFrame(df)

现在,您可以使用此数据框来获取训练数据集。

train_dataset  = s_df.rdd.map(lambda x: LabeledPoint(x[10], x[:10])).collect()