Given a column of dense vectors containing NaN entries, I want to compute the correlation between the vector components. Is there a way to do this without unpacking the vectors to clean the values first?
# pyspark
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.mllib.linalg import Vectors as MlVectors  # Statistics.corr works on mllib vectors
from pyspark.mllib.stat import Statistics


def get_data():
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [
            (Vectors.dense(1., 3., 2.), 0),
            (Vectors.dense(None, 4., 1.), 1),
            (Vectors.dense(3., None, 0.), 2),
            (Vectors.dense(4., 12., None), 3),
            (Vectors.dense(5., 0., 1.), 5),
            (Vectors.dense(6., -1., 0.), 6)], ["features", "foo"])
    return df


def correlation(df):
    digestible_data = df.select("features").rdd.map(lambda row: MlVectors.dense(row[0]))
    print(Statistics.corr(digestible_data))


if __name__ == '__main__':
    correlation(get_data())

# OUTPUT:
# [[  1.  nan  nan]
#  [ nan   1.  nan]
#  [ nan  nan   1.]]
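For context, this NaN-filled result is not Spark-specific: Pearson correlation propagates any missing value through the column means. A plain-NumPy illustration with the same 6x3 matrix (just a sketch, not part of the original question):

```python
import numpy as np

# The same 6x3 matrix as in the DataFrame above; None becomes NaN.
data = np.array([
    [1.,     3.,     2.],
    [np.nan, 4.,     1.],
    [3.,     np.nan, 0.],
    [4.,     12.,    np.nan],
    [5.,     0.,     1.],
    [6.,     -1.,    0.],
])

# A single missing entry in either column makes the whole coefficient
# NaN, which is why Statistics.corr returns NaN off the diagonal above.
corr = np.corrcoef(data, rowvar=False)
print(corr)
```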
Answer 0: (score: 0)
I see that no one wants to dig into this problem, so here is a solution, as slow as it gets:
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.mllib.linalg import Vectors as MlVectors  # Statistics.corr works on mllib vectors
from pyspark.mllib.stat import Statistics
import numpy as np


def get_data():
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [
            (Vectors.dense(1., 3., 2.), 0),
            (Vectors.dense(None, 4., 1.), 1),
            (Vectors.dense(3., None, 0.), 2),
            (Vectors.dense(4., 12., None), 3),
            (Vectors.dense(5., 0., 1.), 5),
            (Vectors.dense(6., -1., 0.), 6)], ["features", "foo"])
    return df


def correlation(df):
    digestible_data = df.select("features").rdd.map(lambda row: MlVectors.dense(row[0]))
    print(Statistics.corr(digestible_data))


def nullproofed_correlation(df, column='features'):
    num_cols = len(df.head()[column])
    res = np.ones((num_cols, num_cols), dtype=np.float32)
    # For each pair of features, keep only the rows where both values are
    # present ("pairwise-complete" deletion) and correlate just that pair.
    for i in range(1, num_cols):
        for j in range(i):
            # Bind i and j as default arguments so the lazily evaluated
            # RDD closures capture the current loop values.
            feature_pair_rdd = df.select(column).rdd.map(
                lambda x, i=i, j=j: MlVectors.dense([x[0][i], x[0][j]]))
            feature_pair_rdd = feature_pair_rdd.filter(
                lambda x: not np.isnan(x[0]) and not np.isnan(x[1]))
            corr_matrix = Statistics.corr(feature_pair_rdd, method="pearson")
            corr = corr_matrix[0, 1]
            res[i, j], res[j, i] = corr, corr
    return res


if __name__ == '__main__':
    correlation(get_data())
    print(nullproofed_correlation(get_data()))
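To sanity-check the pairwise approach without a Spark cluster, the same "pairwise-complete" Pearson correlation can be sketched in plain NumPy (an illustration of the idea, not part of the answer's code):

```python
import numpy as np

# Same matrix as the DataFrame above; None becomes NaN.
data = np.array([
    [1., 3., 2.],
    [np.nan, 4., 1.],
    [3., np.nan, 0.],
    [4., 12., np.nan],
    [5., 0., 1.],
    [6., -1., 0.],
])

n = data.shape[1]
res = np.ones((n, n))
for i in range(1, n):
    for j in range(i):
        # Keep only the rows where both columns are present.
        mask = ~np.isnan(data[:, i]) & ~np.isnan(data[:, j])
        c = np.corrcoef(data[mask, i], data[mask, j])[0, 1]
        res[i, j] = res[j, i] = c
print(res)
```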
In general, correlation can only be computed from the data that is actually present. So it makes sense to create additional columns indicating whether each value is present, compute the correlation on the present data only, and use the "presence" information as an extra feature elsewhere. Unfortunately, Spark's correlation offers no help when dealing with missing data.
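A minimal sketch of the presence-indicator idea in plain NumPy (the column-mean imputation here is an arbitrary choice for illustration, not something the answer prescribes):

```python
import numpy as np

data = np.array([
    [1., 3., 2.],
    [np.nan, 4., 1.],
    [3., np.nan, 0.],
    [4., 12., np.nan],
    [5., 0., 1.],
    [6., -1., 0.],
])

# "Presence" indicator: 1.0 where a value exists, 0.0 where it is NaN.
presence = (~np.isnan(data)).astype(float)

# Fill the holes (here with the column mean, an arbitrary choice) and
# append the presence columns as extra features.
col_means = np.nanmean(data, axis=0)
filled = np.where(np.isnan(data), col_means, data)
augmented = np.hstack([filled, presence])
print(augmented.shape)  # (6, 6)
```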