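Multiclass SVM does not work with the 20 newsgroups dataset

Time: 2018-12-22 11:56:17

Tags: python svm multiclass-classification

I am trying to use the multiclass SVM code from Mblondel's Multiclass SVM. I read his paper, in which he uses the 20newsgroup dataset from sklearn, but when I try to use it myself the code does not work.

I tried changing the code to fit the 20newsgroup dataset, but I am stuck on this error.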

  

Traceback (most recent call last):
  File "F:\env\chatbotstripped\CSSVM.py", line 157, in <module>
    clf.fit(X, y)
  File "F:\env\chatbotstripped\CSSVM.py", line 106, in fit
    v = self._violation(g, y, i)
  File "F:\env\chatbotstripped\CSSVM.py", line 50, in _violation
    elif k != y[i] and self.dual_coef_[k, i] >= 0:
IndexError: index 20 is out of bounds for axis 0 with size 20

This is the main code:

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
news_train = fetch_20newsgroups(subset='train')
X, y = news_train.data[:100], news_train.target[:100]

clf = MulticlassSVM(C=0.1, tol=0.01, max_iter=100, random_state=0, verbose=1)
X = TfidfVectorizer().fit_transform(X)
clf.fit(X, y)
print(clf.score(X, y))

And the fit code (the method that calls _violation):

def fit(self, X, y):
    n_samples, n_features = X.shape

    self._label_encoder = LabelEncoder()
    y = self._label_encoder.fit_transform(y)

    n_classes = len(self._label_encoder.classes_)
    self.dual_coef_ = np.zeros((n_classes, n_samples), dtype=np.float64)
    self.coef_ = np.zeros((n_classes, n_features))

    norms = np.sqrt(np.sum(X.power(2), axis=1)) # i changed this code

    rs = check_random_state(self.random_state)
    ind = np.arange(n_samples)
    rs.shuffle(ind)

    # i added this sparse
    sparse = sp.isspmatrix(X)
    if sparse:
        X = np.asarray(X.data, dtype=np.float64, order='C')

    for it in range(self.max_iter):
        violation_sum = 0
        for ii in range(n_samples):
            i = ind[ii]

            if norms[i] == 0:
                continue

            g = self._partial_gradient(X, y, i)
            v = self._violation(g, y, i)
            violation_sum += v

            if v < 1e-12:
                continue

            delta = self._solve_subproblem(g, y, norms, i)
            self.coef_ += (delta * X[i][:, np.newaxis]).T
            self.dual_coef_[:, i] += delta

        if it == 0:
            violation_init = violation_sum

        vratio = violation_sum / violation_init

        if self.verbose >= 1:
            print("iter", it + 1, "violation", vratio)

        if vratio < self.tol:
            if self.verbose >= 1:
                print("Converged")
            break
    return self

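I know there is a problem with the indexing, but I am not sure how to fix it, and I do not want to break the code because I do not fully understand how it works.

1 Answer:

Answer 0 (score: 0)

You have to convert the sparse matrix returned by the TF-IDF vectorizer into a dense 2D array before passing it to fit. In the modified fit, X.data is only the flat array of non-zero values, so X[i] is no longer the feature vector of sample i. Try this: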

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer

# MulticlassSVM is the class defined in the asker's CSSVM.py.
news_train = fetch_20newsgroups(subset='train')
text, y = news_train.data[:1000], news_train.target[:1000]

clf = MulticlassSVM(C=0.1, tol=0.01, max_iter=100, random_state=0, verbose=1)
# min_df=20 drops terms that appear in fewer than 20 documents, which shrinks the dense matrix.
vectorizer = TfidfVectorizer(min_df=20, stop_words='english')
# toarray() returns a dense 2D ndarray of shape (n_samples, n_features).
X = vectorizer.fit_transform(text).toarray()
clf.fit(X, y)
print(clf.score(X, y))

Output:

iter 1 violation 1.0
iter 2 violation 0.07075102408683964
iter 3 violation 0.018288133735158228
iter 4 violation 0.009149083942255389
Converged
0.953
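
To illustrate why the dense conversion matters, here is a small sketch that is not part of the original code. It uses a toy 3x4 matrix in place of the TF-IDF output and assumes the vectorizer returns a SciPy CSR matrix, as TfidfVectorizer.fit_transform does. The names X_sparse, flat, dense and norms are illustrative only.

```python
import numpy as np
import scipy.sparse as sp

# Toy stand-in for the TF-IDF matrix: 3 samples, 4 features.
X_sparse = sp.csr_matrix(np.array([
    [0.0, 0.5, 0.0, 0.5],
    [1.0, 0.0, 2.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],   # an empty document
]))

# What the modified fit() does: keep only the non-zero values as a flat 1-D array.
flat = np.asarray(X_sparse.data, dtype=np.float64, order='C')

# What the answer does: a regular 2-D array with one row per sample.
dense = X_sparse.toarray()

print(flat.shape)      # one entry per non-zero value, not per sample
print(dense.shape)     # (n_samples, n_features)
print(np.ndim(flat[0]), dense[0].shape)  # flat[0] is a scalar, dense[0] is a full row

# Row norms can be computed on the sparse matrix itself, before converting.
norms = np.sqrt(np.asarray(X_sparse.multiply(X_sparse).sum(axis=1)).ravel())
print(norms)
```

Because flat[i] is a single number rather than a row, any later step that expects one feature vector per sample ends up with arrays of the wrong shape, which is consistent with the out-of-bounds index in the traceback. Note also that X.power(2) in the modified fit exists only on SciPy sparse matrices, so when a dense array is passed in, the norm line would need a NumPy expression such as np.sqrt(np.sum(X ** 2, axis=1)).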