I'm new to ML and its concepts, and I tried to use sklearn's SVR to solve a housing price problem. When I fit the model, I get this error:
(<type 'exceptions.ValueError'>, ValueError("Mix type of y not allowed, got types set(['continuous', 'multiclass'])",), <traceback object at 0x000000001493E388>)
Here is my simple attempt:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn import metrics
# sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split now lives here
from sklearn.model_selection import train_test_split

try:
    dataset = '1000_home.csv'
    data = pd.read_csv(dataset, header=0)
    print(data.shape)
    print(data.head())
    feature_col = ['bedrooms', 'bathrooms', 'sqft_living', 'sqft_lot', 'floors']
    x = data.drop('price', axis=1)  # features: every column except the target
    y = data.price                  # target: price
    x = np.array(x)  # trying this to avoid the error
    y = np.array(y)
    print(x.shape)
    print(y.shape)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)
    print(x_train.shape, " ", x_test.shape)
    print(y_train.shape, " ", y_test.shape)
    print(type(y))
    lm = svm.SVR(kernel='linear')
    lm.fit(x_train, y_train)
    y_pred = lm.predict(x_test)
    print(metrics.classification_report(y_test, y_pred))
    print(metrics.confusion_matrix(y_test, y_pred))
    # plt.show()
    # print(lm.intercept_)
    # print(zip(feature_col, lm.coef_))
    # plt.scatter(data.sqft_living, data.price)
    # plt.show()
except Exception:
    print("error")
    e = sys.exc_info()  # (exception type, exception value, traceback)
    print(e)
Here is a sample of my data, where price is the target (y):
price bedrooms bathrooms sqft_living sqft_lot floors
221900 3 1 1180 5650 1
538000 3 2.25 2570 7242 2
180000 2 1 770 10000 1
604000 4 3 1960 5000 1
510000 3 2 1680 8080 1
1225000 4 4.5 5420 101930 1
257500 3 2.25 1715 6819 2
291850 3 1.5 1060 9711 1
229500 3 1 1780 7470 1
323000 3 2.5 1890 6560 2
Thanks
Answer 0 (score: 1)
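You get this error because you are using a metric meant for classification, not regression. classification_report (and likewise confusion_matrix) expects discrete class labels, but SVR.predict returns continuous values, so y_test and y_pred end up with different target types.
sklearn has a great page listing the metrics available for classification, clustering and regression. There you will see that classification_report is a classification metric. http://scikit-learn.org/stable/modules/model_evaluation.html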
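A minimal sketch of the mismatch, assuming integer prices like those in the question's sample and some made-up SVR-style float predictions (the y_pred values below are hypothetical). sklearn.utils.multiclass.type_of_target shows how scikit-learn labels each array:

```python
import numpy as np
from sklearn.metrics import classification_report
from sklearn.utils.multiclass import type_of_target

# Whole-number prices, as read from the CSV sample in the question
y_test = np.array([221900, 538000, 180000])
# Hypothetical regression output: floats with fractional parts
y_pred = np.array([250123.4, 498765.2, 210456.7])

print(type_of_target(y_test))  # 'multiclass' (integers with more than two distinct values)
print(type_of_target(y_pred))  # 'continuous' (non-integer floats)

# A classification metric refuses to compare the two target types
try:
    classification_report(y_test, y_pred)
except ValueError as exc:
    print("ValueError:", exc)
```

The exact wording of the ValueError differs between scikit-learn versions, but it is raised for the same reason.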
For a support vector machine used for regression (SVR) rather than classification (SVC), you need a regression metric such as R-squared (http://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html#sklearn.metrics.r2_score), mean squared error (http://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html#sklearn.metrics.mean_squared_error) or explained variance (http://scikit-learn.org/stable/modules/generated/sklearn.metrics.explained_variance_score.html#sklearn.metrics.explained_variance_score).
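A sketch of the evaluation step rewritten with those regression metrics, using the ten sample rows from the question in place of 1000_home.csv. Two things here are assumptions not in the original code: a fixed random_state for a repeatable split, and a StandardScaler in front of the SVR, since SVMs are sensitive to feature scale. With only ten rows the scores themselves mean little; the point is that the calls accept continuous predictions.

```python
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

# The ten sample rows from the question; price is the target
data = pd.DataFrame({
    "price": [221900, 538000, 180000, 604000, 510000, 1225000, 257500, 291850, 229500, 323000],
    "bedrooms": [3, 3, 2, 4, 3, 4, 3, 3, 3, 3],
    "bathrooms": [1, 2.25, 1, 3, 2, 4.5, 2.25, 1.5, 1, 2.5],
    "sqft_living": [1180, 2570, 770, 1960, 1680, 5420, 1715, 1060, 1780, 1890],
    "sqft_lot": [5650, 7242, 10000, 5000, 8080, 101930, 6819, 9711, 7470, 6560],
    "floors": [1, 2, 1, 1, 1, 1, 2, 1, 1, 2],
})
X = data.drop("price", axis=1).to_numpy()
y = data["price"].to_numpy()

# random_state is an added assumption so the split is reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# Feature scaling (assumption): standardize X before the linear-kernel SVR
model = make_pipeline(StandardScaler(), SVR(kernel="linear"))
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Regression metrics compare continuous predictions with continuous targets
r2 = metrics.r2_score(y_test, y_pred)
mse = metrics.mean_squared_error(y_test, y_pred)
evs = metrics.explained_variance_score(y_test, y_pred)
print("R^2:", r2)
print("MSE:", mse)
print("Explained variance:", evs)
```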
Answer 1 (score: 0)
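Based on the error, it thinks your y (price) contains both multiclass and continuous values. From the sample you gave, it looks continuous. I would verify that your file is formatted correctly and that the price column contains only continuous values.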
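One way to run that check, sketched with the price values from the question's sample (in practice data would come from pd.read_csv('1000_home.csv')). It looks at the column's dtype, counts entries that fail numeric parsing, and asks scikit-learn how it classifies the column. Note that whole-number prices are reported as 'multiclass' by type_of_target, so a perfectly clean column can still produce the "mix of types" message once it is compared with continuous predictions.

```python
import pandas as pd
from sklearn.utils.multiclass import type_of_target

# Prices from the question's sample; replace with pd.read_csv('1000_home.csv') for the real file
data = pd.DataFrame({"price": [221900, 538000, 180000, 604000, 510000,
                               1225000, 257500, 291850, 229500, 323000]})

# An integer or float dtype means pandas parsed every value as a number
print(data["price"].dtype)

# Count entries that cannot be parsed as numbers (e.g. stray text in the CSV)
bad = pd.to_numeric(data["price"], errors="coerce").isna().sum()
print("non-numeric prices:", bad)

# scikit-learn's view of the target column
print(type_of_target(data["price"]))
```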