OneVsRestClassifier fails to predict a class for binarized multiclass labels

Time: 2019-05-30 21:14:30

Tags: python classification random-forest cross-validation

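I wrapped a random forest classifier in OneVsRestClassifier and trained it on binarized class labels, using KFold cross-validation. I binarized the classes so that I can plot ROC curves.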

My problem is that classifier.predict does not predict any class for some records. Here is my code:

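import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
from sklearn.multiclass import OneVsRestClassifier

# Define classifier
# max_features='sqrt' is what 'auto' meant for classifiers; min_impurity_split=None
# (the old default) is omitted because newer scikit-learn releases no longer accept it.
classifier = OneVsRestClassifier(RandomForestClassifier(bootstrap=False, class_weight=None, criterion='gini',
            max_depth=13, max_features='sqrt', max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_samples_leaf=4, min_samples_split=10,
            min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=3,
            oob_score=False, random_state=27, verbose=0, warm_start=False))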

# Define Method
kf = KFold(n_splits=3, random_state=27, shuffle=True)

scores = []
y_preds = []
y_trues = []
y_probs = []
for train_indices, test_indices in kf.split(X):
    # Perform Fold Validation
    classifier.fit(X[train_indices], y[train_indices])
    y_pred = classifier.predict(X[test_indices])
    y_true = y[test_indices]
    score = classifier.score(X[test_indices], y_true)
    y_prob = classifier.predict_proba(X[test_indices])

    # Append to list
    scores.append(score)
    y_preds.append(y_pred)
    y_trues.append(y_true)
    y_probs.append(y_prob)

    print('iteration..')

# Stack the per-fold results; folds can differ in size, so concatenate the lists directly
scores = np.array(scores)
ypred = np.concatenate(y_preds)
ytrue = np.concatenate(y_trues)
yprob = np.concatenate(y_probs)

My ytrue, ypred and yprob results:

ytrue 
 [[0 0 0 0 0 1]
 [0 0 0 0 0 1]
 [0 0 0 0 0 1]
 ...
 [1 0 0 0 0 0]
 [0 0 0 0 1 0]
 [0 0 1 0 0 0]]
ypred 
 [[0 0 0 0 0 0]
 [0 0 0 1 0 0]
 [0 0 0 0 0 1]
 ...
 [0 0 0 0 0 0]
 [0 0 0 1 0 0]
 [0 0 0 0 0 0]]
yprob 
 [[0.0284438 0.        0.0237302 0.4376503 0.070079  0.3654891]
 [0.018646  0.        0.0111984 0.5008399 0.1122584 0.3558225]
 [0.0770742 0.        0.1195185 0.0839903 0.1776736 0.6168349]
 ...
 [0.0319885 0.0194651 0.2914288 0.1826586 0.3700613 0.1913803]
 [0.0037897 0.0053333 0.0129757 0.552535  0.2664548 0.1772928]
 [0.0030278 0.0026374 0.0673323 0.4866199 0.2321866 0.260554 ]]
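
Since the labels were binarized in order to plot ROC curves, arrays shaped like ytrue and yprob above can be passed to roc_curve one column at a time. The following is only a minimal sketch: `ytrue_demo` and `yprob_demo` are hypothetical random stand-ins for the real stacked arrays, and the score formula is made up purely so the example runs on its own.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

# Hypothetical stand-ins for the stacked out-of-fold arrays (not the real data).
rng = np.random.default_rng(27)
n_classes = 6
labels = rng.integers(0, n_classes, size=300)
ytrue_demo = np.eye(n_classes, dtype=int)[labels]            # one-hot, shape (300, 6)
# Made-up scores: uniform noise plus a bump on the true class column.
yprob_demo = rng.random((300, n_classes)) * 0.5 + 0.4 * ytrue_demo

# One ROC curve per class: column k of the labels against column k of the scores.
roc_auc = {}
for k in range(n_classes):
    fpr, tpr, _ = roc_curve(ytrue_demo[:, k], yprob_demo[:, k])
    roc_auc[k] = auc(fpr, tpr)
    print(f"class {k}: AUC = {roc_auc[k]:.3f}")
```

Note that roc_curve works on the continuous scores, not on the thresholded output of predict, so rows that are all zeros in ypred do not affect the curves.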

Why do some records/rows have no predicted class? What is wrong with my code?
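
One possible explanation, offered as a sketch rather than a confirmed diagnosis: when y is a 2-D indicator matrix, OneVsRestClassifier treats the task as multilabel, and for estimators without a decision_function (such as RandomForestClassifier) predict marks a column only when that column's own probability exceeds 0.5. A row whose largest probability is below 0.5 then comes out as all zeros, which matches the first row of the output above (maximum about 0.44). The data below is synthetic (`make_classification`), and the names `X_demo`, `y_int`, `y_bin`, `top_class` and `one_hot` are illustrative, not from the original code.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

# Synthetic stand-in data (assumption): 6 classes, like the 6 columns in the question.
X_demo, y_int = make_classification(
    n_samples=600, n_features=20, n_informative=10,
    n_classes=6, n_clusters_per_class=1, random_state=27)
y_bin = label_binarize(y_int, classes=np.arange(6))   # indicator matrix, shape (600, 6)

clf = OneVsRestClassifier(
    RandomForestClassifier(n_estimators=50, max_depth=13, random_state=27))
clf.fit(X_demo[:400], y_bin[:400])

proba = clf.predict_proba(X_demo[400:])
pred = clf.predict(X_demo[400:])

# A 2-D indicator target makes the wrapper behave as a multilabel classifier.
print("multilabel:", clf.multilabel_)
# Each column is thresholded on its own, so a row can end up with no label at all.
print("predict == (proba > 0.5):", np.array_equal(pred, (proba > 0.5).astype(int)))
print("rows with no predicted class:", int((pred.sum(axis=1) == 0).sum()))

# To force exactly one class per row, take the most probable column instead.
top_class = proba.argmax(axis=1)
one_hot = label_binarize(top_class, classes=np.arange(6))
```

An alternative under the same assumptions is to fit on the integer labels (`y_int` here) and binarize only for the ROC step; with a 1-D multiclass target, predict returns the class of the highest-scoring estimator for every row.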

0 answers:

No answers yet.