sklearn TruncatedSVD和MLLib SVD之间的值不同?

时间:2018-02-19 16:29:53

标签: scala scikit-learn apache-spark-mllib svd

我在项目中使用降维。我在Python中开始了一个概念验证,然后继续scala。

我在python中使用了TruncatedSVD,在scala中使用了SVD。我的数学实际上与SVD无关,但我测试了一些事情,并对结果感到满意。

我现在注意到一些我无法解释的事情,这让我质疑我的scala代码。

我在相同的数据上尝试了TruncatedSVD(python)和SVD(scala),目标是输出2D,我得到不同的值。

任何人都可以向我解释为什么我没有得到相同的结果?这是代码,无论是在python中,还是在scala中。

的Python:


    from sklearn.decomposition import TruncatedSVD
    import numpy as np
    l = [np.array([0.1, 0.2, 0.2, 0.4]), 
    np.array([0.1, 0.2, 0.2, 0.4]), 
    np.array([0.25, 0.33, 0.54, 0.87]), 
    np.array([0.12, 0.89, 0.12, 0.35])]

    data2D = TruncatedSVD(n_components=2).fit_transform(l)
    print(data2D)
    array([
       [ 0.49114514, -0.08966268],
       [ 0.49114514, -0.08966268],
       [ 1.0559638 , -0.32304442],
       [ 0.81691343,  0.5253898 ]])

Scala的:


    import org.apache.spark.mllib.linalg.{Matrix, SingularValueDecomposition, Vector, Vectors}
    import org.apache.spark.mllib.linalg.distributed.RowMatrix


    val vectors =  sc.parallelize(Array(
    Vectors.dense(Array(0.1, 0.2, 0.2, 0.4)),
    Vectors.dense(Array(0.1, 0.2, 0.2, 0.4)), 
    Vectors.dense(Array(0.25, 0.33, 0.54, 0.87)), 
    Vectors.dense(Array(0.12, 0.89, 0.12, 0.35))))

    val matrix : RowMatrix = new RowMatrix(vectors)
    val svd : SingularValueDecomposition[RowMatrix, Matrix] = matrix.computeSVD(2, computeU = true)
    val data2D = svd.U.rows
    data2D.foreach(println)

    [-0.32635459035606795,0.14239870001655083]
    [-0.32635459035606795,0.14239870001655083]
    [-0.7016635327093717,0.5130463060403057]
    [-0.542820089507427,-0.8344032048869356]

我看到相同的两个原始矢量具有相同的2D表示,但这就是它。我用错了吗?我错过了什么?

非常感谢

JT

编辑:

非平方矩阵的结果,其中我在每个向量的末尾添加了0.18:

的Python:

    array([[ 0.51885074, -0.08919845],
       [ 0.51885074, -0.08919845],
       [ 1.06938646, -0.32415061],
       [ 0.83672766,  0.52490632]])

Scala(带有s和V对象)

svd.U.rows:
(-0.19794364441758672  0.09051048659772198,
-0.5478856677155637   -0.818428534757707,
-0.37169369848891953  0.3725778323351847,
-0.6877514437772323   0.42786133587825526,
-0.22244404876848498  -0.010149946949985795)

svd.V:
(-0.19794364441758672  0.09051048659772198,
-0.5478856677155637   -0.818428534757707,
-0.37169369848891953  0.3725778323351847,
-0.6877514437772323   0.42786133587825526,
-0.22244404876848498  -0.010149946949985795) 

svd.s:
[1.5434094639966702,0.6296927732607006]

0 个答案:

没有答案