Scikit-learn: Understanding MeanShift fit_predict()

Time: 2019-07-13 06:07:41

Tags: python numpy opencv scikit-learn computer-vision

I am using scikit-learn's Mean Shift algorithm to perform image segmentation. My code is below.

I have a flattened color matrix with dimensions 3x994994, so 2984982 samples in total.

import cv2
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt
from itertools import cycle
from PIL import Image

image = Image.open('sample_images/fruit.png').convert('RGB')
image = np.array(image)

red = image[:,:,0]
green = image[:,:,1]
blue = image[:,:,2]

flat_red = red.flatten()
flat_green = green.flatten()
flat_blue = blue.flatten()

flattened = np.stack((flat_red, flat_green, flat_blue))

ms_clf = MeanShift(bin_seeding=True)
ms_labels = ms_clf.fit_predict(flattened)
plt.imshow(np.reshape(ms_labels, [1001, 994]))

This flattened matrix is used as the input to MeanShift's fit_predict() function. When I try to print the labels array returned by fit_predict(), I get the following output:

print(flattened.shape)
(3, 994994)

print(flattened)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]

Doesn't fit_predict() return a label for every data sample? Why do I only get an array with 3 elements in it? Any insight is appreciated.

1 Answer:

Answer 0: (score: 0)

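The documentation for fit_predict() says that it takes an input X of shape (n_samples, n_features) and returns labels of shape (n_samples,). Because you are passing in a 3x994994 array, n_samples = 3 and n_features = 994994, so labels will be an array of shape (3,), as you observed. In effect, each image channel is being "flattened" into a single data sample.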
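To get one label per pixel, the array needs one row per pixel and one column per channel, i.e. shape (n_pixels, 3). The following is a minimal sketch of that idea, not a tested fix for your exact image: it uses a small synthetic two-tone image as a stand-in for fruit.png, and the bandwidth value of 30.0 is an assumption chosen for that synthetic data only.

```python
import numpy as np
from sklearn.cluster import MeanShift

# Hypothetical stand-in for the real image: 20 x 30 RGB pixels,
# dark on the left half and bright on the right half, plus a little noise.
rng = np.random.default_rng(0)
height, width = 20, 30
image = np.zeros((height, width, 3), dtype=np.float64)
image[:, : width // 2] = 30.0
image[:, width // 2 :] = 220.0
image += rng.normal(scale=2.0, size=image.shape)

channels = (image[:, :, 0].ravel(), image[:, :, 1].ravel(), image[:, :, 2].ravel())

# Stacking along the default axis=0 gives (3, n_pixels): 3 samples, n_pixels features.
wrong = np.stack(channels)

# Stacking along axis=1 gives (n_pixels, 3): one sample per pixel, 3 features each.
stacked = np.stack(channels, axis=1)

# Reshaping the image directly produces the same (n_pixels, 3) array.
pixels = image.reshape(-1, 3)

# The bandwidth is an assumed value that suits this synthetic image only.
ms = MeanShift(bandwidth=30.0, bin_seeding=True)
labels = ms.fit_predict(pixels)

# One label per pixel, so the labels can be reshaped back to the image grid.
label_image = labels.reshape(height, width)

print(wrong.shape, pixels.shape, labels.shape, label_image.shape)
```

With the real image, the same reshape (or flattened.T on the existing array) should give an input of shape (994994, 3). Note that Mean Shift can be slow on that many samples, so a subsample or a downscaled image may be worth trying first.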