在SVM
上搜索某些教程时,我在网上找到了Support Vector Machine _ Illustration - 下面的代码,但却产生了weird
图表。在调试代码之后,我想知道原因是否在Date
列表中,确切地说:
dates.append(int(row[0].split('-')[0]))
这是我身边的静态(即2016)或者是否有其他东西,虽然我没有在代码中看到任何异常。
修改
此推论来自语法:
plt.scatter(dates, prices, color ='black', label ='Data');
plt.show()
事实上,产生垂直线,而
dates.append(int(row[0].split('-')[0]))
应该如链接中所述并反映到代码中,将每个日期YYYY-MM-DD
转换为不同的整数值
编辑(2)
将dates.append(md.datestr2num(row[0]))
替换为
dates.append(int(row[0].split('-')[0]))
中的 get_data(filename)
确实有帮助!
import csv
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt
dates = []
prices = []
def get_data(filename):
with open(filename, 'r') as csvfile:
csvFileReader = csv.reader(csvfile)
next(csvFileReader)
for row in csvFileReader:
dates.append(int(row[0].split('-')[0]))
prices.append(float(row[6])) # from 1 i.e from Opening to closing price
return
def predict_prices(dates,prices,x):
dates = np.reshape(dates,(len(dates),1))
svr_lin = SVR(kernel = 'linear', C = 1e3)
svr_poly = SVR(kernel = 'poly', C = 1e3, degree = 2)
svr_rbf = SVR(kernel = 'rbf', C = 1e3, gamma = 0.1)
svr_lin.fit(dates,prices)
svr_poly.fit(dates,prices)
svr_rbf.fit(dates,prices)
plt.scatter(dates, prices, color ='black', label ='Data')
plt.plot(dates, svr_rbf.predict(dates), color ='red', label = 'RBF model')
plt.plot(dates, svr_rbf.predict(dates), color ='green', label = 'Linear model')
plt.plot(dates, svr_rbf.predict(dates), color ='blue', label = 'Polynomial model')
plt.xlabel('Date')
plt.ylabel('Price')
plt.title('Support Vector Regression')
plt.legend
plt.show()
return svr_rbf.predict(x)[0], svr_lin.predict(x)[0], svr_poly.predict(x)[0]
get_data('C:/local/ACA.csv')
predict_prices(dates, prices, 29)
提前致谢
答案 0 :(得分:0)
get_data
创建了两个列表dates
和prices
。
np.array(dates)
和np.array(prices)
会产生什么?形状和dtype?由于您的绘图只显示一个日期,我们需要查看该数组的值范围。
我编辑了你的问题,试着让函数定义正确。确保我做对了。
csv
中的日期列是什么样的?
看起来你的dates
解析会:
In [25]: txt = '2016-02-20'
In [26]: txt.split('-')
Out[26]: ['2016', '02', '20']
In [27]: int(txt.split('-')[0])
Out[27]: 2016
所以你抓住了这一年。这将解释
处的垂直散点图In [29]: 0.010+2.01599e3
Out[29]: 2016.0
我认为这是一个更好的日期转换 - 到np.datetime64
dtype。
In [28]: np.array([txt], dtype='datetime64[D]')
Out[28]: array(['2016-02-20'], dtype='datetime64[D]')
答案 1 :(得分:0)
我一直在使用许多示例(Siraj,Chaitjo,Jaihad等)的SVM代码......并且发现Date需要采用DD-MM-YYYY格式......所以使用的数据是日期...而不是年份日期(正如dark.vapor描述的那样)。
数据只能持续30天......如下所示:
" predict_prices(日期,价格,29)"
否则使用多个月的数据文件(重复日期数字...例如1月15日和2月15日)...我每天都会获得多个价格,而不是每天只有一天的价格。
Edit2:我玩改变数据集,发现数据行可以超过29 ...只要日期只是一个整数序列。我上去了85天(行)...他们都画了。所以我对" 29"在上面的预测代码中做了什么?
能够使用多个月的大型数据文件并选择我想要测试的日期范围会很高兴...但是现在它超出了我的编码技能。
我只是一个新手编码器,所以我希望这是准确的,因为这似乎对我有用,使用DD-MM-YYYY格式工作正常并给我一个很好的清洁情节。
希望这有帮助, 罗伯特
编辑:我刚刚发现了一篇描述此代码的好文章......这证实了" day"使用DD-MM-YYYY格式解析...
https://github.com/mKausthub/stock-er
dates.append(INT(行[0] .split(' - ')[0])) " 当月的 ,因为日期的格式为[日期] - [月] - [年]。"