在scikit学习的kmeans实现中,predict()方法有什么用?

时间:2019-02-04 07:12:06

标签: scikit-learn k-means

有人可以解释在scikit learning的kmeans实现中predict()方法的用途是什么? official documentation声明其用法为:

  

预测X中每个样本所属的最近簇。

但是我也可以通过在fit_transform()方法上训练模型来获得输入集X的每个样本的簇号/标签。那么predict()方法的用途是什么?是否应该为看不见的数据指出最接近的簇?如果是,那么如果执行降维措施(例如SVD),如何处理新的数据点?

这里是similar question,但我仍然认为这没有帮助。

1 个答案:

答案 0 :(得分:1)

  

predict()方法有什么用?是否应该为看不见的数据指出最接近的簇?

是的。

  

那么如果执行降维措施(例如SVD),如何处理新的数据点?

您将相同的降维方法应用于看不见的数据,然后将其传递到.predict()。这是典型的工作流程:

# prerequisites:
#    x_train: training data
#    x_test: "unseen" testing data
#    km: initialized `KMeans()` instance
#    dr: initialized dimensionality reduction instance (such as `TruncatedSVD()`)    

# fitting
x_dr = dr.fit_transform(x_train)
y = km.fit_predict(x_dr)  

# ...

# working with unseen data (models have been fitted before)
x_dr = dr.transform(x_test)
y = km.predict(x_dr)

# ...

实际上,诸如fit_transformfit_predict之类的方法是为了方便起见。 y = km.fit_predict(x)等同于y = km.fit(x).predict(x)

我认为,如果我们按如下方式编写拟合部分,则更容易看到正在发生的事情:

# fitting
dr.fit(x_train)
x_dr = dr.transform(x_train)

km.fit(x_dr)
y = km.predict(x_dr)

除了对.fit()的调用之外,模型在拟合期间和没有可见数据的情况下同样使用。

摘要:

  • .fit()的目的是用数据训练模型。
  • .predict().transform()的目的是将训练有素的模型应用于数据。
  • 如果要在训练过程中拟合模型并将其应用于相同的数据,为方便起见,有.fit_predict().fit_transform()
  • 链接多个模型(例如降维和聚类)时,在拟合和测试过程中应以相同顺序应用它们。