spark.read.format('libsvm') not working in Python

Time: 2019-12-09 07:38:48

Tags: pyspark jupyter pyspark-dataframes

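I'm learning PySpark and have run into a problem I can't solve. I'm following this video and copied code from the PySpark documentation to load data for linear regression. The code I took from the documentation is spark.read.format('libsvm').load('file.txt'). I created a Spark DataFrame before this. When I run this code in a Jupyter notebook, it keeps throwing a Java error, yet the guy in the video does exactly the same thing as me and doesn't get this error. Can someone help me solve this?
Thanks a lot!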

1 answer:

Answer 0 (score: 0)

You can use this custom function to read a libsvm file.

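from pyspark.sql import Row
from pyspark.ml.linalg import SparseVector

def read_libsvm(filepath, spark_session):
    '''
    A utility function that reads a libsvm file and turns it into a pyspark DataFrame.

    Args:
        filepath (str): Local path to the data file (read on the driver with open(),
            so HDFS/S3 paths are not supported).
        spark_session (object): The SparkSession object used to create the DataFrame.

    Returns:
        A pyspark DataFrame with a 'label' column and a 'feat_vector' column.
        LinearRegression expects a 'features' column by default, so pass
        featuresCol='feat_vector' when fitting on this DataFrame.
    '''

    # Read the file on the driver, skipping blank lines.
    with open(filepath, 'r') as f:
        raw_data = [line.split() for line in f if line.strip()]

    # The first token of each line is the label (may be written as "1" or "1.0").
    outcome = [float(x[0]) for x in raw_data]

    # The remaining tokens are "index:value" pairs. libsvm indices are 1-based;
    # shift them to 0-based, as Spark's built-in libsvm reader does.
    index_value_dict = list()
    for row in raw_data:
        pairs = (x.split(':') for x in row[1:])
        index_value_dict.append({int(i) - 1: float(v) for i, v in pairs})

    # Vector size = largest 1-based index in the file (rows with no features are skipped).
    num_features = max((max(d) + 1 for d in index_value_dict if d), default=0)
    rows = [
        Row(
            label=outcome[i],
            feat_vector=SparseVector(num_features, index_value_dict[i])
        )
        for i in range(len(index_value_dict))
    ]
    df = spark_session.createDataFrame(rows)
    return df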

Usage:

my_data = read_libsvm(filepath="sample_libsvm_data.txt", spark_session=spark)