I noticed that the documentation for both sklearn.decomposition.TruncatedSVD and scipy.sparse.linalg.svds says that they perform SVD on sparse matrices. What is the difference between them?
Thanks.
Answer 0 (score: 11)
TruncatedSVD is more feature-rich. It has the scikit-learn API, so you can put it in a sklearn.Pipeline object and call transform on a new matrix, instead of having to work out the matrix multiplications yourself. It offers two algorithms: a fast randomized SVD solver (the default), or scipy.sparse.linalg.svds.
(Full disclosure: I wrote TruncatedSVD.)
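For illustration, here is a minimal sketch of the Pipeline usage described in this answer; the tf-idf step, the toy documents, and the parameter choices are my additions, not from the original answer:

from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD

# TruncatedSVD drops into a Pipeline like any other transformer.
lsa = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('svd', TruncatedSVD(n_components=2)),  # algorithm='randomized' is the default; algorithm='arpack' uses scipy's svds
])

docs = ['the cat sat', 'the dog sat', 'cats and dogs like mats']
Z = lsa.fit_transform(docs)                  # learn the vocabulary and the components
Z_new = lsa.transform(['another document'])  # project unseen text, no manual matrix products
print(Z.shape, Z_new.shape)                  # (3, 2) (1, 2)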
Answer 1 (score: 1)
Another way to see the difference is to run the three variants side by side:
# Check how to use TruncatedSVD
import numpy as np
from sklearn.decomposition import TruncatedSVD

X = [[1, 2, 3], [1, 4, 2], [4, 1, 7], [5, 6, 8]]

# Truncated SVD: fit_transform returns U*S; V is stored in components_
svd = TruncatedSVD(n_components=2, n_iter=7, random_state=42)
US = svd.fit_transform(X)
V = svd.components_
S = svd.singular_values_
print('u,s,v', US, S, V)
print('X_restored dot way', np.round(np.dot(US, V), 1),
      'svdinverse way', np.round(svd.inverse_transform(US), 1))

# Dense SVD via numpy.linalg
U1, S1, V1 = np.linalg.svd(X)
print('u1,s1,v1 remark negative mirrored', U1[:, :2] * S1[:2], V1[:2, :])
print('X restored u1,s1,v1, 2 components',
      np.round(np.dot(U1[:, :2] * S1[:2], V1[:2, :]), 1))

# Sparse SVD via scipy
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import svds
A = csc_matrix(X, dtype=float)
u2, s2, vt2 = svds(A, k=2)
print('sparse reverses !', u2 * s2, vt2)
print('x restored', np.round(np.dot(u2 * s2, vt2), 1))

# Full-rank decomposition and exact reconstruction
# (produces the final, unlabeled block of output below)
print(U1[:, :3], S1, V1)
print(np.round(np.dot(U1[:, :3] * S1, V1), 1))
Results:
u,s,v [[ 3.66997034 -0.34754761]
[ 3.82764223 -2.51681397]
[ 7.61154768 2.83860088]
[11.13470337 -0.96070751]] [14.49264657 3.92883644] [[ 0.44571865 0.46215842 0.76664495]
[ 0.23882889 -0.88677195 0.39572247]]
X_restored dot way
[[1.6 2. 2.7]
[1.1 4. 1.9]
[4.1 1. 7. ]
[4.7 6. 8.2]]
svdinverse way
[[1.6 2. 2.7]
[1.1 4. 1.9]
[4.1 1. 7. ]
[4.7 6. 8.2]]
u1,s1,v1 remark negative mirrored
[[ -3.66997034 0.34754761]
[ -3.82764223 2.51681397]
[ -7.61154768 -2.83860088]
[-11.13470337 0.96070751]] [[-0.44571865 -0.46215842 -0.76664495]
[-0.23882889 0.88677195 -0.39572247]]
X restored u1,s1,v1, 2 components
[[1.6 2. 2.7]
[1.1 4. 1.9]
[4.1 1. 7. ]
[4.7 6. 8.2]]
sparse reverses !
[[-0.34754761 3.66997034]
[-2.51681397 3.82764223]
[ 2.83860088 7.61154768]
[-0.96070751 11.13470337]]
[[ 0.23882889 -0.88677195 0.39572247]
[ 0.44571865 0.46215842 0.76664495]]
x restored
[[1.6 2. 2.7]
[1.1 4. 1.9]
[4.1 1. 7. ]
[4.7 6. 8.2]]
[[-0.25322982 0.0884607 0.88223679]
[-0.26410926 0.64060034 0.16752502]
[-0.52520067 -0.72250421 0.11259767]
[-0.76830021 0.24452723 -0.42534148]]
[14.49264657 3.92883644 0.72625043]
[[-0.44571865 -0.46215842 -0.76664495]
[-0.23882889 0.88677195 -0.39572247]
[-0.86272571 -0.00671608 0.50562757]]
[[1. 2. 3.]
[1. 4. 2.]
[4. 1. 7.]
[5. 6. 8.]]
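Note what the output shows: np.linalg.svd returned the first two singular vectors with flipped signs, which is harmless because an SVD is only unique up to the sign of paired singular vectors, while svds returned its two components in ascending order of singular value, i.e. reversed relative to TruncatedSVD (scipy does not guarantee an ordering). A minimal sketch of sorting the svds output into descending order, on the same toy matrix; the reordering code is my addition, not part of the original answer:

import numpy as np
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import svds

A = csc_matrix([[1, 2, 3], [1, 4, 2], [4, 1, 7], [5, 6, 8]], dtype=float)
u2, s2, vt2 = svds(A, k=2)

# Sort the singular triplets by descending singular value so the
# columns line up with TruncatedSVD's output.
order = np.argsort(s2)[::-1]
u2, s2, vt2 = u2[:, order], s2[order], vt2[order, :]
print(s2)  # expected: [14.49264657  3.92883644]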