有人可以解释在scikit learning的kmeans实现中predict()
方法的用途是什么? official documentation声明其用法为:
预测X中每个样本所属的最近簇。
但是我也可以通过在fit_transform()
方法上训练模型来获得输入集X的每个样本的簇号/标签。那么predict()
方法的用途是什么?是否应该为看不见的数据指出最接近的簇?如果是,那么如果执行降维措施(例如SVD),如何处理新的数据点?
这里是similar question,但我仍然认为这没有帮助。
答案 0 :(得分:1)
predict()方法有什么用?是否应该为看不见的数据指出最接近的簇?
是的。
那么如果执行降维措施(例如SVD),如何处理新的数据点?
您将相同的降维方法应用于看不见的数据,然后将其传递到.predict()
。这是典型的工作流程:
# prerequisites:
# x_train: training data
# x_test: "unseen" testing data
# km: initialized `KMeans()` instance
# dr: initialized dimensionality reduction instance (such as `TruncatedSVD()`)
# fitting
x_dr = dr.fit_transform(x_train)
y = km.fit_predict(x_dr)
# ...
# working with unseen data (models have been fitted before)
x_dr = dr.transform(x_test)
y = km.predict(x_dr)
# ...
实际上,诸如fit_transform
和fit_predict
之类的方法是为了方便起见。 y = km.fit_predict(x)
等同于y = km.fit(x).predict(x)
。
我认为,如果我们按如下方式编写拟合部分,则更容易看到正在发生的事情:
# fitting
dr.fit(x_train)
x_dr = dr.transform(x_train)
km.fit(x_dr)
y = km.predict(x_dr)
除了对.fit()
的调用之外,模型在拟合期间和没有可见数据的情况下同样使用。
摘要:
.fit()
的目的是用数据训练模型。.predict()
或.transform()
的目的是将训练有素的模型应用于数据。.fit_predict()
或.fit_transform()
。