Scikit pipeline - how to access the results of a specific stage

Time: 2019-06-30 00:42:26

Tags: python machine-learning scikit-learn

I have the following pipeline:

from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

pipeline = Pipeline([
    ("kmeans", KMeans(n_clusters=50)),
    ("log_reg", LogisticRegression()),
])
pipeline.fit(X_train, y_train)

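I want to access the labels of kmeans (or any other attribute of KMeans), but I don't know how. I tried print(kmeans.labels_) and even print(pipeline.labels_), but neither works, and I get an error saying the variable is not defined. How can I access the results of a specific stage in the pipeline?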

1 answer:

Answer 0: (score: 1)

With the latest version of sklearn (0.21.2), you can index the steps through the pipeline's __getitem__, that is, with square brackets on the pipeline object.

from sklearn.datasets import make_classification
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# generate some data to play with
# (make_classification is imported directly; the samples_generator
# module was removed in later sklearn releases)
X, y = make_classification(
    n_informative=5, n_redundant=0, random_state=42)

pipeline = Pipeline([
    ("kmeans", KMeans(n_clusters=50)),
    ("log_reg", LogisticRegression(solver='lbfgs')),
])
pipeline.fit(X, y)
pipeline['kmeans'].labels_

# array([ 2, 42, 40, 38, ...])
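
If what you need is the output a stage hands to the next one, rather than a fitted attribute such as labels_, slicing the pipeline is one option. The following is a minimal sketch, assuming scikit-learn 0.21 or later (where Pipeline supports slicing). The values n_clusters=5, n_init=10, random_state=0 and max_iter=1000 are illustrative choices of mine, not part of the question.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Same toy data as above (100 samples, 20 features by default).
X, y = make_classification(n_informative=5, n_redundant=0, random_state=42)

pipeline = Pipeline([
    ("kmeans", KMeans(n_clusters=5, n_init=10, random_state=0)),
    ("log_reg", LogisticRegression(max_iter=1000)),
])
pipeline.fit(X, y)

# Slicing returns a new Pipeline containing only the selected steps;
# the already-fitted estimators are reused, not refitted.
kmeans_only = pipeline[:-1]

# KMeans.transform maps each sample to its distances from the cluster
# centres; this is the feature matrix that log_reg receives.
distances = kmeans_only.transform(X)

# Fitted attributes are still read from the step itself.
labels = pipeline["kmeans"].labels_
```

Here distances has one row per sample and one column per cluster, which makes it easy to inspect what the classifier was actually trained on.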

For earlier versions, use pipeline.named_steps['kmeans'] instead.
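
The attempts in the question fail because there is no variable named kmeans in scope (the fitted estimator lives inside the pipeline), and the Pipeline object itself has no labels_ attribute. A short sketch of the named_steps route, plus positional access through pipeline.steps, is shown here. The dataset and the hyperparameter values (n_clusters=5, n_init=10, random_state=0, max_iter=1000) are assumptions for illustration only.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

X, y = make_classification(n_informative=5, n_redundant=0, random_state=42)

pipeline = Pipeline([
    ("kmeans", KMeans(n_clusters=5, n_init=10, random_state=0)),
    ("log_reg", LogisticRegression(max_iter=1000)),
])
pipeline.fit(X, y)

# Look the step up by the name it was given when the pipeline was built.
kmeans = pipeline.named_steps["kmeans"]

# pipeline.steps is a list of (name, estimator) tuples,
# so positional access works as well.
name, same_kmeans = pipeline.steps[0]

print(kmeans.labels_[:10])            # cluster index of the first 10 samples
print(kmeans.cluster_centers_.shape)  # (n_clusters, n_features)
```

Both lookups return the same fitted KMeans object, so any of its attributes (labels_, cluster_centers_, inertia_) can be read from it.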