How to restrict predict_proba to a single class

Time: 2013-08-21 17:34:01

Tags: machine-learning scikit-learn

When using something like this:

from sklearn.neighbors import KNeighborsClassifier

clf = KNeighborsClassifier(n_neighbors=3)
clf.fit(X, y)
predictions = clf.predict_proba(X_test)

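How can I restrict the prediction to a single class? I need this for performance reasons, for example when I have thousands of classes but only care whether one particular class has a high probability.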
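For reference, the straightforward way to get one class's probability is to take the matching column of the full predict_proba output. This is only a sketch to make the problem concrete: the toy data and the `target` variable are made up for illustration, and the full (n_samples, n_classes) matrix is still computed, which is the cost in question.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Toy data, made up purely for illustration.
rng = np.random.RandomState(0)
X, y = rng.rand(200, 5), rng.randint(0, 10, size=200)
X_test = rng.rand(10, 5)

clf = KNeighborsClassifier(n_neighbors=3).fit(X, y)

target = y[0]  # hypothetical class of interest (guaranteed to be in the training set)

# Columns of predict_proba follow the order of clf.classes_,
# so look up the column index there rather than assuming it.
col = np.flatnonzero(clf.classes_ == target)[0]

# The full probability matrix is still built; we only keep one column.
p_target = clf.predict_proba(X_test)[:, col]
```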

1 answer:

Answer 0: (score: 1)

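scikit-learn does not implement this, so you have to write some kind of wrapper yourself. For example, you can extend the KNeighborsClassifier class and override its predict_proba method.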

According to the source code:

    def predict_proba(self, X):
        """Return probability estimates for the test data X.

        Parameters
        ----------
        X : array, shape = (n_samples, n_features)
            A 2-D array representing the test points.

        Returns
        -------
        p : array of shape = [n_samples, n_classes], or a list of n_outputs
            of such arrays if n_outputs > 1.
            The class probabilities of the input samples. Classes are ordered
            by lexicographic order.
        """
        X = atleast2d_or_csr(X)

        neigh_dist, neigh_ind = self.kneighbors(X)

        classes_ = self.classes_
        _y = self._y
        if not self.outputs_2d_:
            _y = self._y.reshape((-1, 1))
            classes_ = [self.classes_]

        n_samples = X.shape[0]

        weights = _get_weights(neigh_dist, self.weights)
        if weights is None:
            weights = np.ones_like(neigh_ind)

        all_rows = np.arange(X.shape[0])
        probabilities = []
        for k, classes_k in enumerate(classes_):
            pred_labels = _y[:, k][neigh_ind]
            proba_k = np.zeros((n_samples, classes_k.size))

            # a simple ':' index doesn't work right
            for i, idx in enumerate(pred_labels.T):  # loop is O(n_neighbors)
                proba_k[all_rows, idx] += weights[:, i]

            # normalize 'votes' into real [0,1] probabilities
            normalizer = proba_k.sum(axis=1)[:, np.newaxis]
            normalizer[normalizer == 0.0] = 1.0
            proba_k /= normalizer

            probabilities.append(proba_k)

        if not self.outputs_2d_:
            probabilities = probabilities[0]

        return probabilities

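Simply modify this code so that the body of the for k, classes_k in enumerate(classes_): loop accumulates votes only for the one class you need, instead of filling a column for every class.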
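A minimal sketch of such a wrapper follows. It is not scikit-learn's own code: the class name SingleClassKNN, the method predict_proba_one, and the attribute train_labels_ are hypothetical names introduced here. It assumes a single-output problem and the default uniform weights, and it uses only the public kneighbors method so that no private attributes are needed.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


class SingleClassKNN(KNeighborsClassifier):
    """Hypothetical subclass: probability of one class only.

    Assumes single-output labels and weights='uniform'.
    """

    def fit(self, X, y):
        # Keep our own copy of the labels so we do not depend on private attributes.
        self.train_labels_ = np.asarray(y)
        return super().fit(X, y)

    def predict_proba_one(self, X, target_class):
        # Indices of the k nearest training points for each query row.
        neigh_ind = self.kneighbors(X, return_distance=False)
        # Fraction of neighbors carrying the target label.
        # No (n_samples, n_classes) array is ever allocated.
        return (self.train_labels_[neigh_ind] == target_class).mean(axis=1)


# Usage on made-up data: compare against the matching predict_proba column.
rng = np.random.RandomState(0)
X, y = rng.rand(60, 4), rng.randint(0, 5, size=60)
X_test = rng.rand(10, 4)

clf = SingleClassKNN(n_neighbors=3).fit(X, y)
p_one = clf.predict_proba_one(X_test, target_class=2)
```

With uniform weights the result should agree with the corresponding column of predict_proba. Distance-based weights would need the neighbor distances as well, which this sketch leaves out.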

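A hackier alternative is to temporarily overwrite the classes_ attribute so that it contains only the class of interest, and to restore it once you are done. Note that in the code above the values in _y are used as column indices into proba_k, so _y would have to be remapped as well for this to work.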