如何使用scikit预测目标标签了解

时间:2017-11-21 01:11:43

标签: python-3.x machine-learning scikit-learn neural-network

我们说我有一个数据集,我会在这个例子中提供一个玩具示例......

data = pd.DataFrame(np.random.randint(0,100,size=(100, 4)), columns=list('ABCD'))
target = "A"

...生成......

        A   B   C   D
    0  75  38  81  58
    1  36  92  80  79
    2  22  40  19  3
       ...    ...

这显然不足以提供良好的准确性,但是,我们假设我将datatarget提供给提供的random forest算法scikit learn ...

def random_forest(target, data):

    # Drop the target label, which we save separately.
    X = data.drop([target], axis=1).values
    y = data[target].values

    # Run Cross Validation on Random Forest Classifier.
    clf_tree = ske.RandomForestClassifier(n_estimators=50)
    unique_permutations_cross_val(X, y, clf_tree)

unique_permutations_cross_val只是我做的一个交叉验证函数,这是函数(它打印出模型的准确性)......

def unique_permutations_cross_val(X, y, model):

    # Split data 20/80 to be used in a K-Fold Cross Validation with unique permutations.
    shuffle_validator = model_selection.ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)

    # Calculate the score of the model after Cross Validation has been applied to it. 
    scores = model_selection.cross_val_score(model, X, y, cv=shuffle_validator)

    # Print out the score (mean), as well as the variance.
    print("Accuracy: %0.4f (+/- %0.2f)" % (scores.mean(), scores.std()))

无论如何,我的主要问题是,如何使用我创建的模型预测目标标签。例如,我们假设我提供了模型[28, 12, 33]。我希望模型预测target,在这种情况下为"A"

1 个答案:

答案 0 :(得分:0)

已发布的代码中的此模型尚未安装。您进行了交叉验证,它将告诉您模型在数据上的训练情况(或不存在),但它不适合您想要的模型对象。 cross_val_score()使用提供的模型对象的克隆来查找分数。

要预测数据,您需要在模型上显式调用fit()

因此,您可以编辑random_forest方法以返回拟合模型。像这样:

unique_permutations_cross_val(X, y, clf_tree)
clf_tree.fit(X, y)
return clf_tree

然后,只要您调用random_forest方法,就可以执行此操作:

fitted_model = random_forest(target, data)

predictions = fitted_model.predict([data to predict])