Python ta-lib 与 pandas.io.data:烛台图无法绘制,但其他图表都能正常显示

时间:2015-02-25 14:49:52

标签: python pandas candlestick-chart ta-lib

iPython 2.3.1,OS-X Yosemite 10.10.2
Python print(sys.version):
2.7.6 (default, Sep  9 2014, 15:04:36)
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.39)]

以下代码对美国股票数据可以正常工作,例如把证券代码设为 "INTC"(英特尔)。但是,当我获取欧洲股票的数据时,即使数据框中包含全部 OHLC 数据,烛台函数也会失败。我在这里贴出了完整代码,以说明其他技术分析图表在欧洲股票数据上都能正常绘制。

import pandas.io.data as web
import pandas as pd
import numpy as np
import talib as ta
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.dates import date2num
from matplotlib.finance import candlestick
import datetime

ticker = 'DNO.L'

# Download sample data
sec_id = web.get_data_yahoo(ticker, '2014-06-01')

# Data for matplotlib finance plot: one row per bar of sec_id, with columns in
# (date number, open, close, high, low) order -- the order the name
# "ochl" refers to and that matplotlib.finance.candlestick is given below.
# The DatetimeIndex is converted with to_pydatetime() first: date2num (in the
# matplotlib versions this script targets) needs datetime objects, not the
# numpy datetime64 values a raw DatetimeIndex turns into.
# np.column_stack states the column order explicitly instead of relying on
# how a dict of DataFrame columns happens to be ordered.
sec_id_ochl = np.column_stack((date2num(sec_id.index.to_pydatetime()),
                               sec_id.Open,
                               sec_id.Close,
                               sec_id.High,
                               sec_id.Low))

# Technical Analysis
SMA_FAST = 50          # fast simple-moving-average window (rows)
SMA_SLOW = 200         # slow simple-moving-average window (rows)
RSI_PERIOD = 14        # RSI period passed to ta.RSI
RSI_AVG_PERIOD = 15    # window of the moving average drawn over the RSI
MACD_FAST = 12
MACD_SLOW = 26
MACD_SIGNAL = 9
STOCH_K = 14           # stochastic %K look-back period
STOCH_D = 3            # stochastic %D smoothing period
SIGNAL_TOL = 3
Y_AXIS_SIZE = 12       # font size of the indicator panels' y labels

# One row per bar of sec_id; every indicator column is aligned to its index.
analysis = pd.DataFrame(index = sec_id.index)

analysis['sma_f'] = pd.rolling_mean(sec_id.Close, SMA_FAST)
analysis['sma_s'] = pd.rolling_mean(sec_id.Close, SMA_SLOW)
analysis['rsi'] = ta.RSI(sec_id.Close.as_matrix(), RSI_PERIOD)
analysis['sma_r'] = pd.rolling_mean(analysis.rsi, RSI_AVG_PERIOD) # check shift
analysis['macd'], analysis['macdSignal'], analysis['macdHist'] = \
    ta.MACD(sec_id.Close.as_matrix(), fastperiod=MACD_FAST, slowperiod=MACD_SLOW, signalperiod=MACD_SIGNAL)
# TA-Lib's STOCH takes the %K look-back as `fastk_period`; `slowk_period` is
# only the moving average applied to the raw %K. Passing STOCH_K as
# slowk_period would leave the look-back at TA-Lib's default and instead
# smooth %K over STOCH_K bars.
analysis['stoch_k'], analysis['stoch_d'] = \
    ta.STOCH(sec_id.High.as_matrix(), sec_id.Low.as_matrix(), sec_id.Close.as_matrix(),
             fastk_period=STOCH_K, slowd_period=STOCH_D)

# 0/1 signal columns. Comparisons involving NaN (indicator warm-up rows)
# evaluate to False, so those rows get 0.
analysis['sma'] = np.where(analysis.sma_f > analysis.sma_s, 1, 0)
analysis['macd_test'] = np.where((analysis.macd > analysis.macdSignal), 1, 0)
# Oversold-and-rising tests: below 50 and higher than the previous row.
analysis['stoch_k_test'] = np.where((analysis.stoch_k < 50) & (analysis.stoch_k > analysis.stoch_k.shift(1)), 1, 0)
analysis['rsi_test'] = np.where((analysis.rsi < 50) & (analysis.rsi > analysis.rsi.shift(1)), 1, 0)

# Prepare plot: four stacked panels that share one x (date) axis.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True)
ax1.set_ylabel(ticker, size=20)

# Size plot
fig.set_size_inches(15, 30)

# Plot candles. Their x values are matplotlib date numbers (from date2num).
candlestick(ax1, sec_id_ochl, width=0.5, colorup='g', colordown='r', alpha=1)
ax1.xaxis_date()

# Every pandas plot below passes x_compat=True. For a regularly spaced
# DatetimeIndex, pandas can otherwise plot against its own period-based x
# values. Those do not match the date2num x values of the candles, and
# because all panels share x, the candles can end up outside the visible
# range. x_compat=True keeps pandas on ordinary matplotlib date conversion.

# Draw Moving Averages
analysis.sma_f.plot(ax=ax1, c='r', x_compat=True)
analysis.sma_s.plot(ax=ax1, c='g', x_compat=True)

# RSI with 30/50/70 reference lines
ax2.set_ylabel('RSI', size=Y_AXIS_SIZE)
analysis.rsi.plot(ax = ax2, c='g', label = 'Period: ' + str(RSI_PERIOD), x_compat=True)
analysis.sma_r.plot(ax = ax2, c='r', label = 'MA: ' + str(RSI_AVG_PERIOD), x_compat=True)
ax2.axhline(y=30, c='b')
ax2.axhline(y=50, c='black')
ax2.axhline(y=70, c='b')
ax2.set_ylim([0,100])
handles, labels = ax2.get_legend_handles_labels()
ax2.legend(handles, labels)

# Draw MACD computed with Talib
ax3.set_ylabel('MACD: '+ str(MACD_FAST) + ', ' + str(MACD_SLOW) + ', ' + str(MACD_SIGNAL), size=Y_AXIS_SIZE)
analysis.macd.plot(ax=ax3, color='b', label='Macd', x_compat=True)
analysis.macdSignal.plot(ax=ax3, color='g', label='Signal', x_compat=True)
analysis.macdHist.plot(ax=ax3, color='r', label='Hist', x_compat=True)
ax3.axhline(0, lw=2, color='0')
handles, labels = ax3.get_legend_handles_labels()
ax3.legend(handles, labels)

# Stochastic plot with 20/50/80 reference lines
ax4.set_ylabel('Stoch (k,d)', size=Y_AXIS_SIZE)
analysis.stoch_k.plot(ax=ax4, label='stoch_k:'+ str(STOCH_K), color='r', x_compat=True)
analysis.stoch_d.plot(ax=ax4, label='stoch_d:'+ str(STOCH_D), color='g', x_compat=True)
handles, labels = ax4.get_legend_handles_labels()
ax4.legend(handles, labels)
ax4.axhline(y=20, c='b')
ax4.axhline(y=50, c='black')
ax4.axhline(y=80, c='b')

plt.show()

plot

2 个答案:

答案 0 :(得分:2)

我复现了您的代码,遇到了与您相同的错误。我已经修好了,但需要几个步骤。下面贴出修复后的代码,以及为了让代码无错误运行所做的其他步骤。

  1. 修复由于 Yahoo API 变更导致的 pandas_datareader 问题。这个解决方案来自另一篇 StackOverflow 帖子(another StackOverflow post)。

    git clone https://github.com/pydata/pandas-datareader
    cd pandas-datareader/
    git remote add rgkimball http://github.com/rgkimball/pandas-datareader
    git fetch rgkimball fix-yahoo
    git checkout fix-yahoo
    pip2 uninstall pandas_datareader
    python setup.py install
    
  2. 我还按照上面 @ndrw 给出的答案,修改了代码的第 18 行。

  3. 我还做了其他一些修改,都已放在这个 GitHub 仓库中。为方便起见,代码也复制在下面。希望对您有帮助!

    from pandas_datareader import data
    import pandas as pd
    import numpy as np
    import talib as ta
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.gridspec as gridspec
    from matplotlib.dates import date2num
    from matplotlib.finance import candlestick_ohlc as candlestick
    import datetime
    
    ticker = 'OPK'
    
    # Fetch daily prices for the ticker from Google via pandas_datareader.
    sec_id = data.get_data_google(ticker, '2014-06-01')
    
    # Candle input for matplotlib finance: one row per bar, columns ordered
    # (date number, open, close, high, low). The index is turned into Python
    # datetimes before date2num maps it to matplotlib date numbers.
    sec_id_ochl = np.column_stack((
        date2num(sec_id.index.to_pydatetime()),
        sec_id.Open,
        sec_id.Close,
        sec_id.High,
        sec_id.Low,
    ))
    
    # Technical Analysis
    SMA_FAST = 50          # fast simple-moving-average window (rows)
    SMA_SLOW = 200         # slow simple-moving-average window (rows)
    RSI_PERIOD = 14        # RSI period passed to ta.RSI
    RSI_AVG_PERIOD = 15    # window of the moving average drawn over the RSI
    MACD_FAST = 12
    MACD_SLOW = 26
    MACD_SIGNAL = 9
    STOCH_K = 14           # stochastic %K look-back period
    STOCH_D = 3            # stochastic %D smoothing period
    SIGNAL_TOL = 3
    Y_AXIS_SIZE = 12       # font size of the indicator panels' y labels
    
    # One row per bar of sec_id; every indicator column is aligned to its index.
    analysis = pd.DataFrame(index = sec_id.index)
    
    analysis['sma_f'] = pd.rolling_mean(sec_id.Close, SMA_FAST)
    analysis['sma_s'] = pd.rolling_mean(sec_id.Close, SMA_SLOW)
    analysis['rsi'] = ta.RSI(sec_id.Close.as_matrix(), RSI_PERIOD)
    analysis['sma_r'] = pd.rolling_mean(analysis.rsi, RSI_AVG_PERIOD) # check shift
    analysis['macd'], analysis['macdSignal'], analysis['macdHist'] = ta.MACD(sec_id.Close.as_matrix(), fastperiod=MACD_FAST, slowperiod=MACD_SLOW, signalperiod=MACD_SIGNAL)
    # TA-Lib's STOCH takes the %K look-back as `fastk_period`; `slowk_period`
    # is only the moving average applied to the raw %K. Passing STOCH_K as
    # slowk_period would leave the look-back at TA-Lib's default and instead
    # smooth %K over STOCH_K bars.
    analysis['stoch_k'], analysis['stoch_d'] = ta.STOCH(
        sec_id.High.as_matrix(), sec_id.Low.as_matrix(), sec_id.Close.as_matrix(),
        fastk_period=STOCH_K, slowd_period=STOCH_D)
    
    # 0/1 signal columns. Comparisons involving NaN (indicator warm-up rows)
    # evaluate to False, so those rows get 0.
    analysis['sma'] = np.where(analysis.sma_f > analysis.sma_s, 1, 0)
    analysis['macd_test'] = np.where((analysis.macd > analysis.macdSignal), 1, 0)
    # Oversold-and-rising tests: below 50 and higher than the previous row.
    analysis['stoch_k_test'] = np.where((analysis.stoch_k < 50) & (analysis.stoch_k > analysis.stoch_k.shift(1)), 1, 0)
    analysis['rsi_test'] = np.where((analysis.rsi < 50) & (analysis.rsi > analysis.rsi.shift(1)), 1, 0)
    
    # Prepare plot: four stacked panels that share one x (date) axis.
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True)
    ax1.set_ylabel(ticker, size=20)
    
    # Size plot
    fig.set_size_inches(15, 30)
    
    # Plot candles. `candlestick` here is matplotlib.finance.candlestick_ohlc,
    # which reads each row as (time, open, high, low, close). sec_id_ochl is
    # built in (time, open, close, high, low) order, so its columns are
    # reordered. Passing it unchanged would draw the close as the high, the
    # high as the low and the low as the close.
    candlestick(ax1, sec_id_ochl[:, [0, 1, 3, 4, 2]], width=0.5, colorup='g', colordown='r', alpha=1)
    ax1.xaxis_date()
    
    # Every pandas plot below passes x_compat=True. For a regularly spaced
    # DatetimeIndex, pandas can otherwise plot against its own period-based
    # x values. Those do not match the date2num x values of the candles on
    # the shared x axis. x_compat=True keeps pandas on ordinary matplotlib
    # date conversion.
    
    # Draw Moving Averages
    analysis.sma_f.plot(ax=ax1, c='r', x_compat=True)
    analysis.sma_s.plot(ax=ax1, c='g', x_compat=True)
    
    # RSI with 30/50/70 reference lines
    ax2.set_ylabel('RSI', size=Y_AXIS_SIZE)
    analysis.rsi.plot(ax = ax2, c='g', label = 'Period: ' + str(RSI_PERIOD), x_compat=True)
    analysis.sma_r.plot(ax = ax2, c='r', label = 'MA: ' + str(RSI_AVG_PERIOD), x_compat=True)
    ax2.axhline(y=30, c='b')
    ax2.axhline(y=50, c='black')
    ax2.axhline(y=70, c='b')
    ax2.set_ylim([0,100])
    handles, labels = ax2.get_legend_handles_labels()
    ax2.legend(handles, labels)
    
    # Draw MACD computed with Talib
    ax3.set_ylabel('MACD: '+ str(MACD_FAST) + ', ' + str(MACD_SLOW) + ', ' + str(MACD_SIGNAL), size=Y_AXIS_SIZE)
    analysis.macd.plot(ax=ax3, color='b', label='Macd', x_compat=True)
    analysis.macdSignal.plot(ax=ax3, color='g', label='Signal', x_compat=True)
    analysis.macdHist.plot(ax=ax3, color='r', label='Hist', x_compat=True)
    ax3.axhline(0, lw=2, color='0')
    handles, labels = ax3.get_legend_handles_labels()
    ax3.legend(handles, labels)
    
    # Stochastic plot with 20/50/80 reference lines
    ax4.set_ylabel('Stoch (k,d)', size=Y_AXIS_SIZE)
    analysis.stoch_k.plot(ax=ax4, label='stoch_k:'+ str(STOCH_K), color='r', x_compat=True)
    analysis.stoch_d.plot(ax=ax4, label='stoch_d:'+ str(STOCH_D), color='g', x_compat=True)
    handles, labels = ax4.get_legend_handles_labels()
    ax4.legend(handles, labels)
    ax4.axhline(y=20, c='b')
    ax4.axhline(y=50, c='black')
    ax4.axhline(y=80, c='b')
    
    plt.show()
    
  4. 以下是我得到的输出:[输出截图]

答案 1 :(得分:1)

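需要把 pandas 数据的 index(DatetimeIndex)转换为 matplotlib 的日期数值(date2num 的结果)。转换之后索引已经是浮点数,因此后面构造 sec_id_ochl 时应直接使用 sec_id.index,不要再对它调用 date2num。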

import matplotlib.dates as mdates

...

在您的代码中,在构造 sec_id_ochl 之前加入:
# Data for matplotlib finance plot
# Replace the DatetimeIndex with matplotlib date numbers (floats), so that
# candles and pandas plots of data indexed by it share the same x units.
# The index is numeric afterwards: date2num expects datetimes, so do not
# pass the index through date2num a second time.
_py_dates = sec_id.index.to_pydatetime()
sec_id.index = mdates.date2num(_py_dates)
del _py_dates