在scikit-learn中输入fit方法

时间:2016-08-10 06:02:43

标签: python numpy scikit-learn

我正在阅读scikit-learn文档的Type Casting example

我的问题是关于ndarray操作,它是fit方法的输入。 (参考下面的代码)

>>> from sklearn import datasets
>>> from sklearn.svm import SVC
>>> iris = datasets.load_iris()
>>> clf = SVC()
>>> clf.fit(iris.data, iris.target)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))
[0, 0, 0]

>>> clf.fit(iris.data, iris.target_names[iris.target])  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))  
['setosa', 'setosa', 'setosa']

问题:在上面代码clf.fit(iris.data, iris.target_names[iris.target])的这一部分中,执行iris.target_names[iris.target]的操作是什么?

更多信息:

iris.target_names
array(['setosa', 'versicolor', 'virginica'], 
      dtype='|S10')

iris.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

iris.target_names[iris.target]
array(['setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',
   'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',
   'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',
   'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',
   'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',
   'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',
   'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',
   'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',
   'setosa', 'setosa', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'versicolor',
   'versicolor', 'versicolor', 'versicolor', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica', 'virginica', 'virginica',
   'virginica', 'virginica', 'virginica'], 
  dtype='|S10')

我理解我的问题不是scikit-learn具体,而是与numpy操作的理解有关。我已经阅读了numpy文档,但我自己也无法解决这个问题。任何帮助深表感谢。感谢。

1 个答案:

答案 0 :(得分:0)

iris.target在该操作中用作index array

考虑以下数组:

arr = np.array(['a', 'b', 'c'])
arr
Out: 
array(['a', 'b', 'c'], 
      dtype='<U1')

在索引0处,它有'a':

arr[0]
Out: 'a'

在索引0和1处,它有'a'和'b':

arr[[0, 1]]
Out: 
array(['a', 'b'], 
      dtype='<U1')

这些指数可能有重复:

arr[[0, 1, 0]]
Out: 
array(['a', 'b', 'a'], 
      dtype='<U1')

在您的示例中,iris.target是一组编码标签。要获取其名称,请使用iris.target作为iris.target_names的索引,以便为每个元素提供相应的名称。