在这种情况下我应该使用哪种分类器或ML SDK?

时间:2016-05-18 04:28:36

标签: machine-learning scikit-learn random-forest apache-spark-mllib xgboost

训练数据(包括训练和验证集)大约有80个样本,每个样本都有200个密集浮点。有6标记的classe,它们是不平衡的。

在常用的ML库中(例如,libsvmscikit-learnSpark MLlibrandom forestXGBoost或其他),我应该使用哪些?关于硬件配置,机器具有24个CPU核心和250 Gb内存。

1 个答案:

答案 0 :(得分:1)

我建议使用scikit-learn SGDClassifier,因为它是在线的,因此您可以将训练数据以块(迷你批次)加载到内存中并逐渐训练分类器,以便您赢得“#”。需要将所有数据加载到内存中。

高度并行且易于使用。 您可以将 warm_start 参数设置为True,并在每个X,y加载到内存中时多次调用fit,或者您可以使用partial_fit方法的更好选项。

clf = SGDClassifier(loss='hinge', alpha=1e-4, penalty='l2', l1_ratio=0.9, learning_rate='optimal', n_iter=10, shuffle=False, n_jobs=10, fit_intercept=True)
# len(classes) = n_classes
all_classes = np.array(set_of_all_classes)
while True:
    #load a minibatch from disk into memory
    X, y = load_next_chunk()
    clf.partial_fit(X, y, all_classes) 
X_test, y_test = load_test_data()    
y_pred = clf.predict(X_test)