Split a column of DenseMatrices into separate rows (one vector per row)

Time: 2019-05-21 21:39:20

Tags: pyspark pyspark-sql

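I have a column of type matrix in a pyspark.sql.DataFrame.

Each cell in this column is a DenseMatrix of shape (numRows, 268).

That is, the number of rows varies from cell to cell, but the number of columns is always 268.

I would like to split out all the rows of all the matrices in this column, so that each row of the resulting DataFrame is a single vector.

For example, how would I convert the following: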

+------+------------------------------------------------------------------------+
|groups|windows                                                                 |
+------+------------------------------------------------------------------------+
|1     |0.0                 0.0                 1.383419689119171   ... (268 total)
0.0                 1.0308333333333333  1.0                 ...
0.0                 1.0714285714285714  1.0                 ...
0.0                 1.241112828438949   1.0                 ...
0.0                 1.01                1.0212464589235128  ...
0.0                 0.0                 1.0303994011640099  ...
0.0                 1.0310714270488266  0.0                 ...
0.0                 1.7106598984771573  0.0                 ...
0.0                 1.0                 1.7657142857142856  ...
0.0                 1.3483709273182958  1.7071428571428573  ...
0.0                 1.4608788853161845  1.2461538461538462  ...
0.0                 1.0                 0.0                 ...
0.0                 1.0                 0.0                 ...
1.6600496277915633  1.0                 1.0                 ...
1.3537936913895994  1.812121212121212   1.2403100775193798  ...
0.0                 1.6721590909090909  1.0                 ...
1.6479591836734695  0.0                 0.0                 ...
0.0                 1.075               0.0                 ...
1.2246376811594204  0.0                 0.0                 ...
1.0                 1.659994867847062   1.0                 ...
1.0                 0.0                 1.5507936E9         ...
0.0                 1.0                 0.0                 ...
1.6974358974358972  0.0                 0.0                 ...|
|2     |0.0                 0.0                 1.4455958549222798  ... (268 total)
0.0                 1.02875             1.0                 ...
0.0                 1.0714285714285714  1.0                 ...
0.0                 1.2179289026275115  1.0                 ...
0.0                 1.01                1.0191218130311614  ...
0.0                 0.0                 1.028490828331661   ...
0.0                 1.028214284187194   0.0                 ...
0.0                 1.7309644670050761  0.0                 ...
0.0                 1.0                 1.7885714285714287  ...
0.0                 1.3525480367585632  1.7285714285714286  ...
0.0                 1.4683815648445875  1.2153846153846155  ...
0.0                 1.0                 0.0                 ...
0.0                 1.0                 0.0                 ...
1.6972704714640199  1.0                 1.0                 ...
1.3580562659846547  1.8242424242424242  1.2170542635658914  ...
0.0                 1.6971590909090908  1.0                 ...
1.663265306122449   0.0                 0.0                 ...
0.0                 1.0964285714285715  0.0                 ...
1.2028985507246377  0.0                 0.0                 ...
1.0                 1.6782140107775212  1.0                 ...
1.0                 0.0                 1.5507936E9         ...
0.0                 1.0                 0.0                 ...
1.7282051282051283  0.0                 0.0                 ...|
+------+------------------------------------------------------------------------+
only showing top 2 rows

into something like:

|groups|windows
+------+-----------------------------------------------------------------------
|1     |0.0, 0.0, 1.383419689119171, ... (268 total)
|1     |0.0, 1.0308333333333333, 1.0, ...
|1     |0.0, 1.0714285714285714, 1.0, ...
|1     |0.0, 1.241112828438949, 1.0, ...
|1     |0.0, 1.01, 1.0212464589235128, ...
|1     |0.0, 0.0, 1.0303994011640099, ...
|1     |0.0, 1.0310714270488266, 0.0, ...
|1     |0.0, 1.7106598984771573, 0.0, ...
|1     |0.0, 1.0, 1.7657142857142856, ...
|1     |0.0, 1.3483709273182958, 1.7071428571428573, ...
|1     |0.0, 1.4608788853161845, 1.2461538461538462, ...
|1     |0.0, 1.0, 0.0, ...
|1     |0.0, 1.0, 0.0, ...
|1     |1.6600496277915633, 1.0, 1.0, ...
|1     |1.3537936913895994, 1.812121212121212, 1.2403100775193798, ...
|1     |0.0, 1.6721590909090909, 1.0, ...
|1     |1.6479591836734695, 0.0, 0.0, ...
|1     |0.0, 1.075, 0.0, ...
|1     |1.2246376811594204, 0.0, 0.0, ...
|1     |1.0, 1.659994867847062, 1.0, ...
|1     |1.0, 0.0, 1.5507936E9, ...
|1     |0.0, 1.0, 0.0, ...
|1     |1.6974358974358972, 0.0, 0.0, ...
|2     |0.0, 0.0, 1.4455958549222798, ... (268 total)
|2     |0.0, 1.02875, 1.0, ...
|2     |0.0, 1.0714285714285714, 1.0, ...
|2     |0.0, 1.2179289026275115, 1.0, ...
|2     |0.0, 1.01, 1.0191218130311614, ...
|2     |0.0, 0.0, 1.028490828331661, ...
|2     |0.0, 1.028214284187194, 0.0, ...
|2     |0.0, 1.7309644670050761, 0.0, ...
|2     |0.0, 1.0, 1.7885714285714287, ...
|2     |0.0, 1.3525480367585632, 1.7285714285714286, ...
|2     |0.0, 1.4683815648445875, 1.2153846153846155, ...
|2     |0.0, 1.0, 0.0, ...
|2     |0.0, 1.0, 0.0, ...
|2     |1.6972704714640199, 1.0, 1.0, ...
|2     |1.3580562659846547, 1.8242424242424242, 1.2170542635658914, ...
|2     |0.0, 1.6971590909090908, 1.0, ...
|2     |1.663265306122449, 0.0, 0.0, ...
|2     |0.0, 1.0964285714285715, 0.0, ...
|2     |1.2028985507246377, 0.0, 0.0, ...
|2     |1.0, 1.6782140107775212, 1.0, ...
|2     |1.0, 0.0, 1.5507936E9, ...
|2     |0.0, 1.0, 0.0, ...
|2     |1.7282051282051283, 0.0, 0.0, ...
+------+-----------------------------------------------------------------------
only showing top 2 rows

Any help would be greatly appreciated!

EDIT_1

To reiterate, I am starting from a DenseMatrix. Using the explode function I was also able to "solve" my problem, but I had to:

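1) Cast the windows column to a string:

from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, StringType, DoubleType

def stringify_matrices(x):
    # DenseMatrix -> NumPy array -> nested Python list
    arr = x.toArray()
    l = arr.tolist()
    return l

# No return type is given, so the UDF defaults to StringType.
stringify_matrices_udf = udf(lambda y: stringify_matrices(y),)

expanded = \
    extracted.withColumn('expanded',
                        stringify_matrices_udf('windows')
                        )

2) Parse the string into an array of strings (each string representing one vector):

def parse_matrices(x):
    from ast import literal_eval
    # Turn the string back into a nested list, then stringify each row.
    t = literal_eval(str(x))
    str_arr = [str(a) for a in t]
    return str_arr

parse_matrices_udf = udf(lambda y: parse_matrices(y), ArrayType(StringType()))

parsed = \
    expanded.withColumn('parsed',
                        parse_matrices_udf('expanded')
                        )

3) explode:

parsed = parsed.withColumn('exploded', explode(parsed.parsed)).select('groups', 'exploded')

4) Cast to ArrayType(DoubleType()):

def convert_to_double(x):
    # Strip the brackets, split on commas, and convert each item to float.
    str_arr = x.replace('[','').replace(']','').split(',')
    flt_arr = [float(a) for a in str_arr]
    return flt_arr

convert_to_double_udf = udf(lambda y: convert_to_double(y), ArrayType(DoubleType()))

converted = parsed.withColumn('feature_vector', convert_to_double_udf('exploded'))

The above approach works, but I feel there must be a better way to solve this.

EDIT_2

@mayanak agrawal Thanks for your answer! I guess in response I would ask:

How do I convert from a DenseMatrix column, e.g.
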
dm_df = sqlContext.createDataFrame([
        (1, 
         DenseMatrix(numRows=3, numCols=4, values=[2,4,2,5,30,4,2,5,30,4,2,5], isTransposed=True)),
        (2, 
         DenseMatrix(numRows=2, numCols=4, values=[2,1,3,7,2,4,2,9], isTransposed=True)),
        (3, 
         DenseMatrix(numRows=4, numCols=4, values=[2,4,2,5,2,4,2,5,2,1,3,7,2,1,3,7], isTransposed=True))],
        ['groups', 'windows'])
dm_df.show()

+------+-----------------------------------------------------------------------------------+
|groups|windows                                                                            |
+------+-----------------------------------------------------------------------------------+
|1     |2.0   4.0  2.0  5.0  
30.0  4.0  2.0  5.0  
30.0  4.0  2.0  5.0                    |
|2     |2.0  1.0  3.0  7.0  
2.0  4.0  2.0  9.0                                            |
|3     |2.0  4.0  2.0  5.0  
2.0  4.0  2.0  5.0  
2.0  1.0  3.0  7.0  
2.0  1.0  3.0  7.0  |
+------+-----------------------------------------------------------------------------------+

into a column of 2D float arrays (as in your example):

arr_df = sqlContext.createDataFrame([
        (1, [[2,4,2,5],[30,4,2,5],[30,4,2,5]]),
        (2, [[2,1,3,7],[2,4,2,9]]),
        (3, [[2,4,2,5],[2,4,2,5],[2,1,3,7],[2,1,3,7]])],
        ['groups', 'windows'])
arr_df.show()

Thanks again!
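
For reference, the relationship between the DenseMatrix frame and the nested-list frame can be modelled without Spark. This is a minimal sketch, assuming only that DenseMatrix stores `values` column-major by default and row-major when `isTransposed=True`; the helper name `matrix_rows` is hypothetical and not part of PySpark:

```python
def matrix_rows(num_rows, num_cols, values, is_transposed=False):
    """Return the matrix as a list of rows (each row a list of floats)."""
    if len(values) != num_rows * num_cols:
        raise ValueError("values must hold num_rows * num_cols entries")
    if is_transposed:
        # Row-major: each consecutive chunk of num_cols values is one row.
        return [[float(v) for v in values[r * num_cols:(r + 1) * num_cols]]
                for r in range(num_rows)]
    # Column-major: entry (r, c) sits at index r + c * num_rows.
    return [[float(values[r + c * num_rows]) for c in range(num_cols)]
            for r in range(num_rows)]


# Same arguments as the group-2 matrix in dm_df.
rows = matrix_rows(2, 4, [2, 1, 3, 7, 2, 4, 2, 9], is_transposed=True)
print(rows)  # [[2.0, 1.0, 3.0, 7.0], [2.0, 4.0, 2.0, 9.0]]
```

With `isTransposed=True`, each run of `numCols` values is already one row, which is why the rows of dm_df line up with the inner lists of arr_df.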

1 Answer:

Answer 0 (score: 0)

I could not recreate your exact example DataFrame, so I created a smaller version of it. Let me know if anything needs to change.

import pyspark.sql.functions as F

df = sql.createDataFrame([
        (1, [[2,4,2,5],[30,4,2,5],[30,4,2,5]]),
        (2, [[2,1,3,7],[2,4,2,9]]),
        (3, [[2,4,2,5,3],[2,4,2,5],[2,1,3,7],[2,1,3,7]])],
        ['groups', 'windows'])

Exploding on the 'windows' column gives the desired result.

df = df.select(['groups', F.explode(F.col('windows')).alias('windows')])

This gives the output:

+------+---------------+
|groups|        windows|
+------+---------------+
|     1|   [2, 4, 2, 5]|
|     1|  [30, 4, 2, 5]|
|     1|  [30, 4, 2, 5]|
|     2|   [2, 1, 3, 7]|
|     2|   [2, 4, 2, 9]|
|     3|[2, 4, 2, 5, 3]|
|     3|   [2, 4, 2, 5]|
|     3|   [2, 1, 3, 7]|
|     3|   [2, 1, 3, 7]|
+------+---------------+
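
What explode does to the nested array column can be mimicked in plain Python, which may help when reasoning about the result. This is only an illustrative model of the row-level behaviour, not how Spark executes it; `explode_rows` is a hypothetical helper:

```python
def explode_rows(records):
    """Yield one (group, row) pair per inner list, repeating the group id."""
    for group, windows in records:
        for row in windows:
            yield group, row


data = [
    (1, [[2, 4, 2, 5], [30, 4, 2, 5], [30, 4, 2, 5]]),
    (2, [[2, 1, 3, 7], [2, 4, 2, 9]]),
]
exploded = list(explode_rows(data))
for group, row in exploded:
    print(group, row)
```

As in this model, a record whose array is empty contributes no output rows; in Spark, explode_outer can be used when such records should be kept.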

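Edit:

Once the matrix is converted to a list, I was able to explode it directly. There is no need to convert it to a string; just specify the return data type in stringify_matrices_udf.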

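import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, FloatType
# Assumption: the ml flavour of DenseMatrix; use pyspark.mllib.linalg instead
# if the column was built with the RDD-based API.
from pyspark.ml.linalg import DenseMatrix

def stringify_matrices(x):
    # DenseMatrix -> NumPy array -> nested list of Python floats
    arr = x.toArray()
    l = arr.tolist()
    print(l)  # debug output
    return l


# `sql` is an existing SQLContext (or SparkSession).
df = sql.createDataFrame([
        (1, 
         DenseMatrix(numRows=3, numCols=4, values=[2,4,2,5,30,4,2,5,30,4,2,5], isTransposed=True)),
        (2, 
         DenseMatrix(numRows=2, numCols=4, values=[2,1,3,7,2,4,2,9], isTransposed=True)),
        (3, 
         DenseMatrix(numRows=4, numCols=4, values=[2,4,2,5,2,4,2,5,2,1,3,7,2,1,3,7], isTransposed=True))],
        ['groups', 'windows'])

# Declaring the return type is what makes the result an array column
# that can be exploded directly.
stringify_matrices_udf = F.udf(lambda y: stringify_matrices(y),ArrayType(ArrayType(FloatType()))) 

df = \
    df.withColumn('expanded',
                        stringify_matrices_udf('windows')
                        ) \
      .select(['groups', F.explode(F.col('expanded')).alias('windows')])

df.show()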

This gives

+------+--------------------+
|groups|             windows|
+------+--------------------+
|     1|[2.0, 4.0, 2.0, 5.0]|
|     1|[30.0, 4.0, 2.0, ...|
|     1|[30.0, 4.0, 2.0, ...|
|     2|[2.0, 1.0, 3.0, 7.0]|
|     2|[2.0, 4.0, 2.0, 9.0]|
|     3|[2.0, 4.0, 2.0, 5.0]|
|     3|[2.0, 4.0, 2.0, 5.0]|
|     3|[2.0, 1.0, 3.0, 7.0]|
|     3|[2.0, 1.0, 3.0, 7.0]|
+------+--------------------+
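
One caveat, offered as a suggestion rather than as part of the original answer: FloatType is single precision, whereas the values held in a DenseMatrix are doubles, so declaring the UDF's return type as ArrayType(ArrayType(DoubleType())) should avoid rounding values such as those in the question. The rounding itself can be seen with the standard library alone; the helper `to_float32` below is hypothetical:

```python
import struct


def to_float32(x):
    """Round a Python float (a double) to the nearest single-precision value."""
    return struct.unpack('f', struct.pack('f', x))[0]


original = 1.0308333333333333   # a value from the question's matrix
rounded = to_float32(original)

print(original == rounded)             # False: some precision was lost
print(abs(original - rounded) < 1e-6)  # True: the error is small
```

Small integers such as those in the toy example survive the round trip unchanged, which is why the issue does not show up in the output above.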