我在项目中使用降维。我在Python中开始了一个概念验证,然后继续scala。
我在python中使用了TruncatedSVD,在scala中使用了SVD。我的数学实际上与SVD无关,但我测试了一些事情,并对结果感到满意。
我现在注意到一些我无法解释的事情,这让我质疑我的scala代码。
我在相同的数据上尝试了TruncatedSVD(python)和SVD(scala),目标是输出2D,我得到不同的值。
任何人都可以向我解释为什么我没有得到相同的结果?这是代码,无论是在python中,还是在scala中。
的Python:
from sklearn.decomposition import TruncatedSVD
import numpy as np
l = [np.array([0.1, 0.2, 0.2, 0.4]),
np.array([0.1, 0.2, 0.2, 0.4]),
np.array([0.25, 0.33, 0.54, 0.87]),
np.array([0.12, 0.89, 0.12, 0.35])]
data2D = TruncatedSVD(n_components=2).fit_transform(l)
print(data2D)
array([
[ 0.49114514, -0.08966268],
[ 0.49114514, -0.08966268],
[ 1.0559638 , -0.32304442],
[ 0.81691343, 0.5253898 ]])
Scala的:
import org.apache.spark.mllib.linalg.{Matrix, SingularValueDecomposition, Vector, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
val vectors = sc.parallelize(Array(
Vectors.dense(Array(0.1, 0.2, 0.2, 0.4)),
Vectors.dense(Array(0.1, 0.2, 0.2, 0.4)),
Vectors.dense(Array(0.25, 0.33, 0.54, 0.87)),
Vectors.dense(Array(0.12, 0.89, 0.12, 0.35))))
val matrix : RowMatrix = new RowMatrix(vectors)
val svd : SingularValueDecomposition[RowMatrix, Matrix] = matrix.computeSVD(2, computeU = true)
val data2D = svd.U.rows
data2D.foreach(println)
[-0.32635459035606795,0.14239870001655083]
[-0.32635459035606795,0.14239870001655083]
[-0.7016635327093717,0.5130463060403057]
[-0.542820089507427,-0.8344032048869356]
我看到相同的两个原始矢量具有相同的2D表示,但这就是它。我用错了吗?我错过了什么?
非常感谢
JT
编辑:
非平方矩阵的结果,其中我在每个向量的末尾添加了0.18:
的Python:
array([[ 0.51885074, -0.08919845],
[ 0.51885074, -0.08919845],
[ 1.06938646, -0.32415061],
[ 0.83672766, 0.52490632]])
Scala(带有s和V对象)
svd.U.rows:
(-0.19794364441758672 0.09051048659772198,
-0.5478856677155637 -0.818428534757707,
-0.37169369848891953 0.3725778323351847,
-0.6877514437772323 0.42786133587825526,
-0.22244404876848498 -0.010149946949985795)
svd.V:
(-0.19794364441758672 0.09051048659772198,
-0.5478856677155637 -0.818428534757707,
-0.37169369848891953 0.3725778323351847,
-0.6877514437772323 0.42786133587825526,
-0.22244404876848498 -0.010149946949985795)
svd.s:
[1.5434094639966702,0.6296927732607006]