MeanShift`fit`与`fit_predict`scikitlearn

时间:2014-11-27 03:05:51

标签: scikit-learn

假设X是典型形式的数组。鉴于代码。

from sklearn.cluster import MeanShift
ms = MeanShift(bin_seeding=True,cluster_all=False)
ms.fit(X)

执行此操作后,ms有两个属性:labels_cluster_centers_所以我的第一个问题是...... ms.fit_predict(X)ms.predict(X)有什么意义,因为我们已经有了一个X的分类,我们可以从labels_读取?

1 个答案:

答案 0 :(得分:1)

主要区别在于,当您说ms.fit(X)时,X是您标记的数据集/训练数据集。在说ms.fit_predict(X')时,X'是您的未标记/测试数据集。即,您正在使用fit_predict预测未标记的数据集。 即,fit(X)执行群集,而fit_predict则为群集标签。在ms.predict(X)对象上没有sklearn.cluster.mean_shift_.MeanShift之类的东西。 另请参阅下面的dir(ms)

>>> help(ms.fit)
Help on method fit in module sklearn.cluster.mean_shift_:

fit(self, X) method of sklearn.cluster.mean_shift_.MeanShift instance
    Perform clustering.

    Parameters
    -----------
    X : array-like, shape=[n_samples, n_features]
        Samples to cluster.

>>> help(ms.fit_predict)
Help on method fit_predict in module sklearn.base:

fit_predict(self, X, y=None) method of sklearn.cluster.mean_shift_.MeanShift instance
    Performs clustering on X and returns cluster labels.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Input data.

    Returns
    -------
    y : ndarray, shape (n_samples,)
        cluster labels


dir(ms)
['__class__', '__delattr__', '__dict__', '__doc__', '__format__', '__getattribute__', '__hash__', '__init__', '__module__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_get_param_names', 'bandwidth', 'bin_seeding', 'cluster_all', 'fit', 'fit_predict', 'get_params', 'min_bin_freq', 'seeds', 'set_params']

ms属性为_labels& _cluster_centersX数据,您可以使用标准误分类惩罚技术估算模型的优劣。您无法使用fit_predict进行估算,因为您只会获得标签,而不会获得群集中心。因此,您可以根据自己的良好标准来设计群集中心。