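I have a column of type matrix in a pyspark.sql.DataFrame. Each cell in this column is a DenseMatrix of shape (numRows, 268), i.e. the number of rows varies from cell to cell, but the number of columns is always 268.
I want to split all rows of all matrices in this column so that each row of the resulting DataFrame is a single vector.
For example, how would I transform the following: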
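+------+-------------------------------------------------------------------------
|groups|windows |
+------+-------------------------------------------------------------------------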
|1 |0.0 0.0 1.383419689119171 ... (268 total)
0.0 1.0308333333333333 1.0 ...
0.0 1.0714285714285714 1.0 ...
0.0 1.241112828438949 1.0 ...
0.0 1.01 1.0212464589235128 ...
0.0 0.0 1.0303994011640099 ...
0.0 1.0310714270488266 0.0 ...
0.0 1.7106598984771573 0.0 ...
0.0 1.0 1.7657142857142856 ...
0.0 1.3483709273182958 1.7071428571428573 ...
0.0 1.4608788853161845 1.2461538461538462 ...
0.0 1.0 0.0 ...
0.0 1.0 0.0 ...
1.6600496277915633 1.0 1.0 ...
1.3537936913895994 1.812121212121212 1.2403100775193798 ...
0.0 1.6721590909090909 1.0 ...
1.6479591836734695 0.0 0.0 ...
0.0 1.075 0.0 ...
1.2246376811594204 0.0 0.0 ...
1.0 1.659994867847062 1.0 ...
1.0 0.0 1.5507936E9 ...
0.0 1.0 0.0 ...
1.6974358974358972 0.0 0.0 ...|
|2 |0.0 0.0 1.4455958549222798 ... (268 total)
0.0 1.02875 1.0 ...
0.0 1.0714285714285714 1.0 ...
0.0 1.2179289026275115 1.0 ...
0.0 1.01 1.0191218130311614 ...
0.0 0.0 1.028490828331661 ...
0.0 1.028214284187194 0.0 ...
0.0 1.7309644670050761 0.0 ...
0.0 1.0 1.7885714285714287 ...
0.0 1.3525480367585632 1.7285714285714286 ...
0.0 1.4683815648445875 1.2153846153846155 ...
0.0 1.0 0.0 ...
0.0 1.0 0.0 ...
1.6972704714640199 1.0 1.0 ...
1.3580562659846547 1.8242424242424242 1.2170542635658914 ...
0.0 1.6971590909090908 1.0 ...
1.663265306122449 0.0 0.0 ...
0.0 1.0964285714285715 0.0 ...
1.2028985507246377 0.0 0.0 ...
1.0 1.6782140107775212 1.0 ...
1.0 0.0 1.5507936E9 ...
0.0 1.0 0.0 ...
1.7282051282051283 0.0 0.0 ...|
+------+-------------------------------------------------------------------------
only showing top 2 rows
into something like:
+------+-------------------------------------------------------------------------
|groups|windows
+------+-------------------------------------------------------------------------
|1     |0.0, 0.0, 1.383419689119171, ... (268 total)
|1     |0.0, 1.0308333333333333, 1.0, ...
|1     |0.0, 1.0714285714285714, 1.0, ...
|1     |0.0, 1.241112828438949, 1.0, ...
|1     |0.0, 1.01, 1.0212464589235128, ...
|1     |0.0, 0.0, 1.0303994011640099, ...
|1     |0.0, 1.0310714270488266, 0.0, ...
|1     |0.0, 1.7106598984771573, 0.0, ...
|1     |0.0, 1.0, 1.7657142857142856, ...
|1     |0.0, 1.3483709273182958, 1.7071428571428573, ...
|1     |0.0, 1.4608788853161845, 1.2461538461538462, ...
|1     |0.0, 1.0, 0.0, ...
|1     |0.0, 1.0, 0.0, ...
|1     |1.6600496277915633, 1.0, 1.0, ...
|1     |1.3537936913895994, 1.812121212121212, 1.2403100775193798, ...
|1     |0.0, 1.6721590909090909, 1.0, ...
|1     |1.6479591836734695, 0.0, 0.0, ...
|1     |0.0, 1.075, 0.0, ...
|1     |1.2246376811594204, 0.0, 0.0, ...
|1     |1.0, 1.659994867847062, 1.0, ...
|1     |1.0, 0.0, 1.5507936E9, ...
|1     |0.0, 1.0, 0.0, ...
|1     |1.6974358974358972, 0.0, 0.0, ...
|2     |0.0, 0.0, 1.4455958549222798, ... (268 total)
|2     |0.0, 1.02875, 1.0, ...
|2     |0.0, 1.0714285714285714, 1.0, ...
|2     |0.0, 1.2179289026275115, 1.0, ...
|2     |0.0, 1.01, 1.0191218130311614, ...
|2     |0.0, 0.0, 1.028490828331661, ...
|2     |0.0, 1.028214284187194, 0.0, ...
|2     |0.0, 1.7309644670050761, 0.0, ...
|2     |0.0, 1.0, 1.7885714285714287, ...
|2     |0.0, 1.3525480367585632, 1.7285714285714286, ...
|2     |0.0, 1.4683815648445875, 1.2153846153846155, ...
|2     |0.0, 1.0, 0.0, ...
|2     |0.0, 1.0, 0.0, ...
|2     |1.6972704714640199, 1.0, 1.0, ...
|2     |1.3580562659846547, 1.8242424242424242, 1.2170542635658914, ...
|2     |0.0, 1.6971590909090908, 1.0, ...
|2     |1.663265306122449, 0.0, 0.0, ...
|2     |0.0, 1.0964285714285715, 0.0, ...
|2     |1.2028985507246377, 0.0, 0.0, ...
|2     |1.0, 1.6782140107775212, 1.0, ...
|2     |1.0, 0.0, 1.5507936E9, ...
|2     |0.0, 1.0, 0.0, ...
|2     |1.7282051282051283, 0.0, 0.0, ...
+------+-------------------------------------------------------------------------
Any help would be greatly appreciated!
EDIT_1
To reiterate: I am starting from a DenseMatrix column. Using the explode function I was also able to "solve" my problem, but I had to:
1) cast the windows column to a string:
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, StringType, DoubleType

def stringify_matrices(x):
    arr = x.toArray()   # numpy array of shape (numRows, numCols)
    l = arr.tolist()    # nested Python list, one inner list per matrix row
    return l

# No return type is given, so the UDF defaults to StringType and stores the list as text
stringify_matrices_udf = udf(stringify_matrices)
expanded = extracted.withColumn('expanded', stringify_matrices_udf('windows'))
2) parse the string into an array of strings (each string represents one vector):
def parse_matrices(x):
    from ast import literal_eval
    t = literal_eval(str(x))       # text -> nested Python list
    str_arr = [str(a) for a in t]  # one string per matrix row
    return str_arr

parse_matrices_udf = udf(parse_matrices, ArrayType(StringType()))
parsed = expanded.withColumn('parsed', parse_matrices_udf('expanded'))
3) explode:
parsed = parsed.withColumn('exploded', explode(parsed.parsed)).select('groups', 'exploded')
4) cast to ArrayType(DoubleType()):
def convert_to_double(x):
    # "[a, b, c]" -> [a, b, c] as Python floats
    str_arr = x.replace('[', '').replace(']', '').split(',')
    flt_arr = [float(a) for a in str_arr]
    return flt_arr

convert_to_double_udf = udf(convert_to_double, ArrayType(DoubleType()))
converted = parsed.withColumn('feature_vector', convert_to_double_udf('exploded'))
The above works, but I feel there must be a better way to solve this.
EDIT_2 @mayank agrawal Thanks for your answer! In response, I would ask:
How do I convert a DenseMatrix column, for example:
from pyspark.ml.linalg import DenseMatrix  # assumed; pyspark.mllib.linalg.DenseMatrix takes the same arguments

dm_df = sqlContext.createDataFrame([
    (1, DenseMatrix(numRows=3, numCols=4, values=[2,4,2,5,30,4,2,5,30,4,2,5], isTransposed=True)),
    (2, DenseMatrix(numRows=2, numCols=4, values=[2,1,3,7,2,4,2,9], isTransposed=True)),
    (3, DenseMatrix(numRows=4, numCols=4, values=[2,4,2,5,2,4,2,5,2,1,3,7,2,1,3,7], isTransposed=True))],
    ['groups', 'windows'])
dm_df.show()
+------+-----------------------------------------------------------------------------------+
|groups|windows |
+------+-----------------------------------------------------------------------------------+
|1 |2.0 4.0 2.0 5.0
30.0 4.0 2.0 5.0
30.0 4.0 2.0 5.0 |
|2 |2.0 1.0 3.0 7.0
2.0 4.0 2.0 9.0 |
|3 |2.0 4.0 2.0 5.0
2.0 4.0 2.0 5.0
2.0 1.0 3.0 7.0
2.0 1.0 3.0 7.0 |
+------+-----------------------------------------------------------------------------------+
to a column of 2D floats (as in your example):
arr_df = sqlContext.createDataFrame([
(1, [[2,4,2,5],[30,4,2,5],[30,4,2,5]]),
(2, [[2,1,3,7],[2,4,2,9]]),
(3, [[2,4,2,5],[2,4,2,5],[2,1,3,7],[2,1,3,7]])],
['groups', 'windows'])
arr_df.show()
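As a side note on how the dm_df example lays out its values: Spark's DenseMatrix stores values in column-major order by default and in row-major order when isTransposed=True, which is why the rows in the output above read straight off each values list. The plain-Python sketch below (no Spark needed; dense_rows is a hypothetical helper, not a PySpark function) splits such a flat list into rows, roughly what toArray().tolist() yields for each cell:

```python
from typing import List, Sequence


def dense_rows(num_rows: int, num_cols: int, values: Sequence[float],
               is_transposed: bool = False) -> List[List[float]]:
    """Split a DenseMatrix-style flat value list into its rows.

    Hypothetical helper mirroring Spark's DenseMatrix layout:
    column-major by default, row-major when is_transposed is True.
    """
    if len(values) != num_rows * num_cols:
        raise ValueError("values must have num_rows * num_cols entries")
    if is_transposed:
        # Row-major: row i occupies values[i*num_cols : (i+1)*num_cols]
        return [[float(v) for v in values[i * num_cols:(i + 1) * num_cols]]
                for i in range(num_rows)]
    # Column-major: element (i, j) lives at values[j*num_rows + i]
    return [[float(values[j * num_rows + i]) for j in range(num_cols)]
            for i in range(num_rows)]


# Group 2 from the dm_df example: 2 x 4, values given row by row
rows = dense_rows(2, 4, [2, 1, 3, 7, 2, 4, 2, 9], is_transposed=True)
print(rows)  # [[2.0, 1.0, 3.0, 7.0], [2.0, 4.0, 2.0, 9.0]]
```

Each inner list is one vector, which is exactly the unit that explode turns into its own row.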
Thanks again!
Answer (score: 0)
I couldn't recreate your exact example DataFrame, so I created a smaller version of it. Let me know if anything needs to change.
import pyspark.sql.functions as F
df = sql.createDataFrame([
(1, [[2,4,2,5],[30,4,2,5],[30,4,2,5]]),
(2, [[2,1,3,7],[2,4,2,9]]),
(3, [[2,4,2,5,3],[2,4,2,5],[2,1,3,7],[2,1,3,7]])],
['groups', 'windows'])
Exploding the 'windows' column gives the desired result:
df = df.select(['groups', F.explode(F.col('windows')).alias('windows')])
This gives the output:
+------+---------------+
|groups| windows|
+------+---------------+
| 1| [2, 4, 2, 5]|
| 1| [30, 4, 2, 5]|
| 1| [30, 4, 2, 5]|
| 2| [2, 1, 3, 7]|
| 2| [2, 4, 2, 9]|
| 3|[2, 4, 2, 5, 3]|
| 3| [2, 4, 2, 5]|
| 3| [2, 1, 3, 7]|
| 3| [2, 1, 3, 7]|
+------+---------------+
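For intuition about what F.explode does here: it emits one output row per array element and copies the other selected columns (here groups) onto each one. A row whose array is null or empty produces no output, whereas F.explode_outer keeps it with a null. Below is a plain-Python sketch of those semantics; explode_rows and the extra group 4 are hypothetical and only for illustration:

```python
from typing import Any, Iterable, List, Optional, Sequence, Tuple


def explode_rows(rows: Iterable[Tuple[Any, Optional[Sequence[Any]]]],
                 outer: bool = False) -> List[Tuple[Any, Any]]:
    """Mimic select('groups', explode('windows')) on (group, array) pairs."""
    out = []
    for group, arr in rows:
        if not arr:  # null or empty array
            if outer:
                # explode_outer keeps the row and emits a null element
                out.append((group, None))
            continue
        for element in arr:
            # One output row per array element, group value copied along
            out.append((group, element))
    return out


data = [(1, [[2, 4, 2, 5], [30, 4, 2, 5], [30, 4, 2, 5]]),
        (2, [[2, 1, 3, 7], [2, 4, 2, 9]]),
        (4, [])]  # hypothetical extra group with an empty array
for row in explode_rows(data):
    print(row)
```

With outer=False, group 4 disappears from the result, which matches explode rather than explode_outer.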
EDIT:
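Once the matrix is converted to a list, I was able to explode it directly; there is no need to convert it to a string. Just specify the return data type in stringify_matrices_udf.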
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.ml.linalg import DenseMatrix  # assumed; pyspark.mllib.linalg.DenseMatrix also works

def stringify_matrices(x):
    arr = x.toArray()   # numpy array of shape (numRows, numCols)
    l = arr.tolist()    # list of rows, each a list of Python floats
    return l

df = sql.createDataFrame([
    (1, DenseMatrix(numRows=3, numCols=4, values=[2,4,2,5,30,4,2,5,30,4,2,5], isTransposed=True)),
    (2, DenseMatrix(numRows=2, numCols=4, values=[2,1,3,7,2,4,2,9], isTransposed=True)),
    (3, DenseMatrix(numRows=4, numCols=4, values=[2,4,2,5,2,4,2,5,2,1,3,7,2,1,3,7], isTransposed=True))],
    ['groups', 'windows'])

# Declaring the nested array type avoids any string round trip;
# DoubleType keeps full float64 precision, FloatType would round to 32 bits
stringify_matrices_udf = F.udf(stringify_matrices, ArrayType(ArrayType(DoubleType())))

df = df.withColumn('expanded', stringify_matrices_udf('windows')) \
       .select(['groups', F.explode(F.col('expanded')).alias('windows')])
df.show()
This gives:
+------+--------------------+
|groups| windows|
+------+--------------------+
| 1|[2.0, 4.0, 2.0, 5.0]|
| 1|[30.0, 4.0, 2.0, ...|
| 1|[30.0, 4.0, 2.0, ...|
| 2|[2.0, 1.0, 3.0, 7.0]|
| 2|[2.0, 4.0, 2.0, 9.0]|
| 3|[2.0, 4.0, 2.0, 5.0]|
| 3|[2.0, 4.0, 2.0, 5.0]|
| 3|[2.0, 1.0, 3.0, 7.0]|
| 3|[2.0, 1.0, 3.0, 7.0]|
+------+--------------------+
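To compare the two approaches at the value level, here is a plain-Python sketch (no Spark) that reuses the asker's parse_matrices and convert_to_double logic from EDIT_1 on a hypothetical stringified cell. It assumes the text looks like a printed nested list; the exact string Spark produces may differ. In this example, parsing the text back recovers the same floats as returning tolist() directly, so the string route mostly adds parsing work. The declared return type matters more: Spark's FloatType is 32-bit, so values like 1.383419689119171 from the question would be rounded, while DoubleType keeps them. The hypothetical to_float32 helper emulates that rounding with the struct module:

```python
import struct
from ast import literal_eval
from typing import List


def parse_matrices(x: str) -> List[str]:
    # Step 2 from EDIT_1: nested-list text -> one string per matrix row
    return [str(a) for a in literal_eval(str(x))]


def convert_to_double(x: str) -> List[float]:
    # Step 4 from EDIT_1: "[a, b, c]" -> [a, b, c] as Python floats
    return [float(a) for a in x.replace('[', '').replace(']', '').split(',')]


def to_float32(x: float) -> float:
    # Emulate storing a value as a 32-bit float (what FloatType keeps)
    return struct.unpack('f', struct.pack('f', x))[0]


# Hypothetical cell built from two rows of group 1 in the question
direct = [[0.0, 0.0, 1.383419689119171], [1.0, 0.0, 1.5507936e9]]
cell_text = "[[0.0, 0.0, 1.383419689119171], [1.0, 0.0, 1.5507936E9]]"

round_trip = [convert_to_double(s) for s in parse_matrices(cell_text)]
print(round_trip == direct)            # the string path recovers the same values
print(to_float32(1.383419689119171))   # close to, but not equal to, the input
```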