如何根据列索引列表从pyspark中的csv文件中选择某些列,然后确定它们的不同长度

时间:2016-04-25 04:00:48

标签: python csv apache-spark pyspark

我在pyspark中有此代码,其中我将index列的值作为list传递。现在我想从csv文件中为这些相应的索引选择列:

def ml_test(input_col_index):

    sc = SparkContext(master='local', appName='test')

    inputData = sc.textFile('hdfs://localhost:/dir1').zipWithIndex().filter(lambda (line, rownum): rownum >= 0).map(lambda (line, rownum): line)

if __name__ == '__main__':

    input_col_index = sys.argv[1] # For example - ['1','2','3','4']

    ml_test(input_col_index)

现在,如果我想从上面的csv文件中选择一组静态或硬编码的列,我可以这样做,但是这里所需列的indexes作为参数传递。另外,我必须计算每个所选列的不同长度,我知道可以通过colmn_1 = input_data.map(lambda x: x[0]).distinct().collect()完成,但是如何对未预先知道的列集进行计算,并根据索引列表确定在运行时传递?

注意:我必须计算不同的列长度,因为我必须将该长度作为参数传递给Pysparks RandomForest算法。

2 个答案:

答案 0 :(得分:1)

您可以使用列表推导。

# given a list of indicies...
indicies = [int(i) for i in input_col_index]

# select only those columns from each row
rdd = rdd.map(lambda x: [x[idx] for idx in indicies])

# for all rows, choose longest columns
longest_per_column = rdd.reduce(
    lambda x, y: [max(a, b, key=len) for a, b in zip(x, y)])

# get lengths of longest columns
print([len(x) for x in longest_per_column])

reduce函数有两个列表,同时循环遍历每个值,并通过选择(对于每一列)较长的一个列来创建一个新列表。

更新:要将长度传递给RandomForest构造函数,您可以执行以下操作:

column_lengths = [len(x) for x in longest_per_column]

model = RandomForest.trainRegressor(
    categoricalFeaturesInfo=dict(enumerate(column_lengths)),
    maxBins=max(column_lengths),
    # ...
)

答案 1 :(得分:0)

我会推荐这个简单的解决方案。

假设我们有以下CSV文件结构[1]:

"TRIP_ID","CALL_TYPE","ORIGIN_CALL","ORIGIN_STAND","TAXI_ID","TIMESTAMP","DAY_TYPE","MISSING_DATA","POLYLINE"
"1372636858620000589","C","","","20000589","1372636858","A","False","[[-8.618643,41.141412],[-8.618499,41.141376]]"

并且您只想选择列:CALL_TYPE, TIMESTAMP, POLYLINE 首先,您需要格式化数据,然后只需拆分并选择所需的列。这很简单:

from pyspark import SparkFiles
raw_data = sc.textFile("data.csv")
callType_days = raw_data.map(lambda x: x.replace('""','"NA"').replace('","', '\n').replace('"','')) \
    .map(lambda x: x.split()) \
    .map(lambda x: (x[1],x[5],x[8]))

callType_days.take(2)

结果将是:

[(u'CALL_TYPE', u'TIMESTAMP', u'POLYLINE'),
 (u'C',
  u'1372636858',
  u'[[-8.618643,41.141412],[-8.618499,41.141376]]')]

之后,使用这样的结构化数据非常容易。

[1]:Taxi Service Trajectory - Prediction Challenge, ECML PKDD 2015 Data Set