支持Vector Machine Python 3.5.2

时间:2017-01-14 23:52:19

标签: python numpy matplotlib machine-learning svm

SVM上搜索某些教程时,我在网上找到了Support Vector Machine _ Illustration - 下面的代码,但却产生了weird图表。在调试代码之后,我想知道原因是否在Date列表中,确切地说:

dates.append(int(row[0].split('-')[0]))

这是我身边的静态(即2016)或者是否有其他东西,虽然我没有在代码中看到任何异常。

修改

此推论来自语法:

plt.scatter(dates, prices, color ='black', label ='Data'); 
plt.show()
事实上,

产生垂直线,而

dates.append(int(row[0].split('-')[0]))

应该如链接中所述并反映到代码中,将每个日期YYYY-MM-DD转换为不同的整数值

编辑(2)

dates.append(md.datestr2num(row[0]))替换为

函数dates.append(int(row[0].split('-')[0]))中的

get_data(filename)确实有帮助!

enter image description here

import csv
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt

dates = []
prices = []

def get_data(filename):
    with open(filename, 'r') as csvfile:
        csvFileReader = csv.reader(csvfile)
        next(csvFileReader)
        for row in csvFileReader:
            dates.append(int(row[0].split('-')[0]))
            prices.append(float(row[6]))  # from 1 i.e from Opening to closing price

    return

def predict_prices(dates,prices,x):
    dates = np.reshape(dates,(len(dates),1))
    svr_lin = SVR(kernel = 'linear', C = 1e3)
    svr_poly = SVR(kernel = 'poly', C = 1e3, degree = 2)
    svr_rbf = SVR(kernel = 'rbf',  C = 1e3, gamma = 0.1)  

    svr_lin.fit(dates,prices)
    svr_poly.fit(dates,prices)
    svr_rbf.fit(dates,prices)

    plt.scatter(dates, prices, color ='black', label ='Data')
    plt.plot(dates, svr_rbf.predict(dates), color ='red', label = 'RBF model')
    plt.plot(dates, svr_rbf.predict(dates), color ='green', label = 'Linear model')
    plt.plot(dates, svr_rbf.predict(dates), color ='blue', label = 'Polynomial model')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.title('Support Vector Regression')
    plt.legend
    plt.show()
    return svr_rbf.predict(x)[0], svr_lin.predict(x)[0], svr_poly.predict(x)[0]

get_data('C:/local/ACA.csv')
predict_prices(dates, prices, 29)

提前致谢

2 个答案:

答案 0 :(得分:0)

get_data创建了两个列表datesprices

np.array(dates)np.array(prices)会产生什么?形状和dtype?由于您的绘图只显示一个日期,我们需要查看该数组的值范围。

我编辑了你的问题,试着让函数定义正确。确保我做对了。

csv中的日期列是什么样的?

看起来你的dates解析会:

In [25]: txt = '2016-02-20'

In [26]: txt.split('-')
Out[26]: ['2016', '02', '20']

In [27]: int(txt.split('-')[0])
Out[27]: 2016

所以你抓住了这一年。这将解释

处的垂直散点图
In [29]: 0.010+2.01599e3
Out[29]: 2016.0

我认为这是一个更好的日期转换 - 到np.datetime64 dtype。

In [28]: np.array([txt], dtype='datetime64[D]')
Out[28]: array(['2016-02-20'], dtype='datetime64[D]')

答案 1 :(得分:0)

我一直在使用许多示例(Siraj,Chaitjo,Jaihad等)的SVM代码......并且发现Date需要采用DD-MM-YYYY格式......所以使用的数据是日期...而不是年份日期(正如dark.vapor描述的那样)。

数据只能持续30天......如下所示:

" predict_prices(日期,价格,29)"

否则使用多个月的数据文件(重复日期数字...例如1月15日和2月15日)...我每天都会获得多个价格,而不是每天只有一天的价格。

Edit2:我玩改变数据集,发现数据行可以超过29 ...只要日期只是一个整数序列。我上去了85天(行)...他们都画了。所以我对" 29"在上面的预测代码中做了什么?

能够使用多个月的大型数据文件并选择我想要测试的日期范围会很高兴...但是现在它超出了我的编码技能。

我只是一个新手编码器,所以我希望这是准确的,因为这似乎对我有用,使用DD-MM-YYYY格式工作正常并给我一个很好的清洁情节。

希望这有帮助, 罗伯特

编辑:我刚刚发现了一篇描述此代码的好文章......这证实了" day"使用DD-MM-YYYY格式解析...

https://github.com/mKausthub/stock-er

dates.append(INT(行[0] .split(' - ')[0])) " 当月的 ,因为日期的格式为[日期] - [月] - [年]。"