我有使用numpy数组做支持向量机的以下问题。
import numpy as np
from sklearn import svm
我有3个班级/标签(male
,female
,na
),表示如下:
labels = [0,1,2]
每个班级由3个变量(height
,weight
,age
)定义为培训数据:
male_height = np.array([111,121,137,143,157])
male_weight = np.array([60,70,88,99,75])
male_age = np.array([41,32,73,54,35])
males = np.hstack([male_height,male_weight,male_age])
female_height = np.array([91,121,135,98,90])
female_weight = np.array([32,67,98,86,56])
female_age = np.array([51,35,33,67,61])
females = np.hstack([female_height,female_weight,female_age])
na_height = np.array([96,127,145,99,91])
na_weight = np.array([42,97,78,76,86])
na_age = np.array([56,35,49,64,66])
nas = np.hstack([na_height,na_weight,na_age])
现在我必须使用支持向量机方法来训练数据来预测给定这三个变量的类:
height_weight_age = [100,100,100]
clf = svm.SVC()
trainingData = np.vstack([males,females,nas])
clf.fit(trainingData, labels)
result = clf.predict(height_weight_age)
print result
不幸的是,发生以下错误:
ValueError: X.shape[1] = 3 should be equal to 15, the number of features at training time
我应该如何修改trainingData
和labels
以获得正确答案?
答案 0 :(得分:2)
hstack
给出1-d数组。你需要2-d形状(n_samples, n_features)
的数组,你可以从vstack
得到它。
In [7]: males = np.hstack([male_height,male_weight,male_age])
In [8]: males
Out[8]:
array([111, 121, 137, 143, 157, 60, 70, 88, 99, 75, 41, 32, 73,
54, 35])
In [9]: np.vstack([male_height,male_weight,male_age])
Out[9]:
array([[111, 121, 137, 143, 157],
[ 60, 70, 88, 99, 75],
[ 41, 32, 73, 54, 35]])
In [10]: np.vstack([male_height,male_weight,male_age]).T
Out[10]:
array([[111, 60, 41],
[121, 70, 32],
[137, 88, 73],
[143, 99, 54],
[157, 75, 35]])
您还需要传递反映每个样本标签的标签/数组,而不是仅仅枚举存在的标签。在修复所有变量之后,我可以训练SVM并按如下方式应用它:
In [19]: clf = svm.SVC()
In [20]: y = ["male"] * 5 + ["female"] * 5 + ["na"] * 5
In [21]: X = np.vstack([males, females, nas])
In [22]: clf.fit(X, y)
Out[22]:
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
In [23]: height_weight_age = [100,100,100]
In [24]: clf.predict(height_weight_age)
Out[24]:
array(['female'],
dtype='|S6')
(请注意,我使用字符串标签而不是数字标签。我还建议您标准化要素值,因为它们的范围相当不同。)