交叉验证和过采样(SMOTE)功能

时间:2019-05-15 12:32:53

标签: python cross-validation oversampling

我写了下面的代码。 X是形状为(1000,5)的数据帧,而y是形状为(1000,1)的数据帧。 y是要预测的目标数据,并且不平衡。我想应用交叉验证和SMOTE。

def Learning(n, est, X, y):
    s_k_fold = StratifiedKFold(n_splits = n)
    acc_scores = []
    rec_scores = []
    f1_scores = []

    for train_index, test_index in s_k_fold.split(X, y): 
        X_train = X[train_index]
        y_train = y[train_index]    

        sm = SMOTE(random_state=42)
        X_resampled, y_resampled = sm.fit_resample(X_train, y_train)

        X_test = X[test_index]
        y_test = y[test_index]

        est.fit(X_resampled, y_resampled)
        y_pred = est.predict(X_test)
        acc_scores.append(accuracy_score(y_test, y_pred))
        rec_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred)) 

    print('Accuracy:',np.mean(acc_scores))
    print('Recall:',np.mean(rec_scores))
    print('F1:',np.mean(f1_scores)) 

Learning(3, SGDClassifier(), X_train_s_pca, y_train)

运行代码时,出现以下错误:

  

[Int64Index([4231,4235,4246,4250,4255,4295,4317,   4344,4381,\ n 4387,\ n ... \ n 13122,   13123、13124、13125、13126、13127、13128、13129、13130,\ n
  13131],\ n dtype ='int64',length = 8754)]在[列]“中

感谢帮助使其运行。

1 个答案:

答案 0 :(得分:0)

如果您仔细观察错误堆栈跟踪(这很重要,但没有包括在内),则应该看到错误来自这些行(并将来自其他类似行):

X_train = X[train_index]

这种选择仅适用于Numpy数组的行的方式。由于您使用的是Pandas DataFrame,因此应使用loc

X_train = X.loc[train_index]

或者,您也可以使用values将DataFrame转换为Numpy数组(以最大程度地减少代码更改):

Learning(3, SGDClassifier(), X_train_s_pca.values, y_train.values)