加速预测

时间:2015-06-29 06:35:14

标签: scikit-learn

使用5个特征创建的SVM模型和使用默认参数的3000个样本进行的预测使用5个特征和100000个样本花费了非常长的时间(超过一小时)。有加速预测的方法吗?

3 个答案:

答案 0 :(得分:2)

这里要考虑的几个问题:

  1. 您是否标准化了输入矩阵X? SVM不是规模不变的,因此如果算法在没有适当缩放的情况下获取大量原始输入,那么算法可能难以进行分类。

  2. 参数C的选择:较高的C允许更复杂的非平滑决策边界,并且需要更多时间来适应这种复杂性。因此,将值C从默认值1降低到较低值可能会加快此过程。

  3. 还建议您选择gamma的正确值。这可以通过网格搜索 - 交叉验证来完成。

  4. 以下是进行网格搜索交叉验证的代码。为简单起见,我忽略了测试集。

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC
    from sklearn.pipeline import make_pipeline
    from sklearn.grid_search import GridSearchCV
    from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, make_scorer
    
    # generate some artificial data
    X, y = make_classification(n_samples=3000, n_features=5, weights=[0.1, 0.9])
    
    # make a pipeline for convenience
    pipe = make_pipeline(StandardScaler(), SVC(kernel='rbf', class_weight='auto'))
    
    # set up parameter space, we want to tune SVC params C and gamma
    # the range below is 10^(-5) to 1 for C and 0.01 to 100 for gamma
    param_space = dict(svc__C=np.logspace(-5,0,5), svc__gamma=np.logspace(-2, 2, 10))
    
    # choose your customized scoring function, popular choices are f1_score, accuracy_score, recall_score, roc_auc_score
    my_scorer = make_scorer(roc_auc_score, greater_is_better=True)
    # construct grid search
    gscv = GridSearchCV(pipe, param_space, scoring=my_scorer)
    gscv.fit(X, y)
    # what's the best estimator
    gscv.best_params_
    
    Out[20]: {'svc__C': 1.0, 'svc__gamma': 0.21544346900318834}
    
    # what's the best score, in our case, roc_auc_score
    gscv.best_score_
    
    Out[22]: 0.86819366014152421
    

    注意:SVC仍然没有快速运行。计算50种可能的参数组合需要40多秒才能完成。

    %time gscv.fit(X, y)
    CPU times: user 42.6 s, sys: 959 ms, total: 43.6 s
    Wall time: 43.6 s
    

答案 1 :(得分:1)

因为特征的数量相对较少,我将从减少惩罚参数开始。它控制列车数据中错误标记样本的惩罚,并且由于您的数据包含5个特征,我猜它不是完全线性可分的。

通常,此参数(var firstChild = document.body.childNodes[0]; if(firstChild && firstChild.nodeType === 3 && firstChild.textContent.trim() === "{") { document.body.removeChild(firstChild); } )允许分类器在更高准确度的帐户上具有更大的余量(有关详细信息,请参阅this

默认情况下,C。从C=1.0开始,看看情况如何。

答案 2 :(得分:1)

一个原因可能是参数gamma不相同。

默认情况下sklearn.svm.SVC使用RBF内核,gamma为0.0,在这种情况下将使用1 / n_features。因此,gamma在不同的特征数量下会有所不同。

就建议而言,我同意Jianxun's answer