SMOTE算法初始条件

时间:2016-12-18 05:30:59

标签: python machine-learning classification imblearn

我正在使用python imbalanced-learn包中的SMOTE算法:

from imblearn.over_sampling import SMOTE
sm = SMOTE(kind='regular', n_neighbors = 4)
 :
X_train_resampled, y_train_resampled = sm.fit_sample(X_train, y_train)

我已明确设置n_neighbors = 4。但是,我从上面的代码中得到以下错误:

ValueError                                Traceback (most recent call last)
<ipython-input-2-9e9116d71706> in <module>()
     33 
     34     #try:
---> 35     X_train_resampled, y_train_resampled = sm.fit_sample(X_train, y_train)
     36     #except:
     37     #continue

/usr/local/lib/python3.4/dist-packages/imblearn/base.py in fit_sample(self, X, y)
    176         """
    177 
--> 178         return self.fit(X, y).sample(X, y)
    179 
    180     def _validate_ratio(self):

/usr/local/lib/python3.4/dist-packages/imblearn/base.py in sample(self, X, y)
    153             self._validate_ratio()
    154 
--> 155         return self._sample(X, y)
    156 
    157     def fit_sample(self, X, y):

/usr/local/lib/python3.4/dist-packages/imblearn/over_sampling/smote.py in _sample(self, X, y)
    287             nns = self.nearest_neighbour.kneighbors(
    288                 X_min,
--> 289                 return_distance=False)[:, 1:]
    290 
    291             self.logger.debug('Create synthetic samples ...')

/usr/local/lib/python3.4/dist-packages/sklearn/neighbors/base.py in kneighbors(self, X, n_neighbors, return_distance)
    341                 "Expected n_neighbors <= n_samples, "
    342                 " but n_samples = %d, n_neighbors = %d" %
--> 343                 (train_size, n_neighbors)
    344             )
    345         n_samples, _ = X.shape

ValueError: Expected n_neighbors <= n_samples,  but n_samples = 5, n_neighbors = 6

知道为什么n_neighbors = 4的设置不起作用?

1 个答案:

答案 0 :(得分:0)

正确的参数是:

  

k_neighbors :int或object,optional(default = 5)

     

如果是int,用于构造合成样本的最近邻居数。 if object,一个继承自sklearn.neighbors.base.KNeighborsMixin的估计器,用于查找k_neighbors。

您通过 n 通知 n_neighbors ,但正确的是 k_neighbors ,并带有k!

该消息是因为 5 默认

阅读文档here