Python, pandas, computing RSI: dtype=timedelta64[ns] and int

Time: 2020-09-03 12:24:46

Tags: python pandas dataframe datetime integer

I am trying to compute the RSI of a stock, based on this tutorial:

Compute RSI tutorial

I refactored the code into:

import pandas_datareader.data as web  # assumed: the question does not show how `web` is imported


class GetStockRSI:
    @classmethod
    def create_stock_df(cls, stock, start_date, end_date):
        df = web.get_data_yahoo(stock, start=start_date, end=end_date)
        # use numerical integer index instead of date
        stock_df = df.reset_index()
        return stock_df

    @classmethod
    def compute_rsi(cls, stock, start_date, end_date):
        data = cls.create_stock_df(stock, start_date, end_date)
        time_window = 14
        diff = data.diff(1).dropna()  # row-to-row difference (one trading day)

        # this preserves the dimensions of the diff values
        up_chg = 0 * diff
        down_chg = 0 * diff

        # up change is equal to the positive difference, otherwise equal to zero
        up_chg[diff > 0] = diff[diff > 0]

        # down change is equal to negative difference, otherwise equal to zero
        down_chg[diff < 0] = diff[diff < 0]

        # values are related to exponential decay
        # we set com=time_window-1 so we get decay alpha=1/time_window
        up_chg_avg = up_chg.ewm(com=time_window - 1, min_periods=time_window).mean()
        down_chg_avg = down_chg.ewm(com=time_window - 1, min_periods=time_window).mean()

        rs = abs(up_chg_avg / down_chg_avg)
        rsi = 100 - 100 / (1 + rs)
        return rsi

    @classmethod
    def main(cls, stock, start_date, end_date):
        rsi = cls.compute_rsi(stock=stock, start_date=start_date, end_date=end_date)
        print(rsi)

Computing the differences raises the following error:

TypeError: Invalid comparison between dtype=timedelta64[ns] and int

It refers to this line: up_chg[diff > 0] = diff[diff > 0]

I can't see where the error comes from.

Sample dataframe:

         Date        Open        High         Low       Close   Adj Close  Volume
0  2020-08-19  287.799988  291.980011  281.200012  285.609985  285.609985  459900
1  2020-08-20  284.440002  295.320007  283.390015  294.470001  294.470001  526900
2  2020-08-21  294.000000  295.790009  281.339996  284.399994  284.399994  580100
3  2020-08-24  289.500000  291.660004  275.589996  279.239990  279.239990  541100
4  2020-08-25  278.950012  295.904999  276.790009  292.790009  292.790009  852100
5  2020-08-26  295.450012  304.279999  291.329987  296.929993  296.929993  634700
6  2020-08-27  299.089996  299.089996  286.000000  291.049988  291.049988  466900
7  2020-08-28  294.000000  301.760010  290.510010  291.970001  291.970001  365200
8  2020-08-31  293.470001  298.355011  292.640015  294.630005  294.630005  420000
9  2020-09-01  296.750000  305.984985  294.651001  299.709991  299.709991  819300
10 2020-09-02  301.779999  301.779999  280.720001  296.959991  296.959991  779600

1 answer:

Answer 0: (score: 0)

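This is caused by the Date column. After reset_index(), Date is an ordinary datetime column, so data.diff(1) turns it into a timedelta column in diff, and a timedelta cannot be compared with an integer.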

diff > 0 fails because diff['Date'] > 0 fails.
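To see this in isolation, here is a minimal sketch that reproduces the failure on a tiny hand-made frame. The three rows are taken from the sample dataframe with prices rounded, and names such as `comparison_failed` and `close_mask` are only illustrative.

```python
import pandas as pd

# Tiny stand-in for the downloaded data: a datetime column plus one price column.
df = pd.DataFrame({
    "Date": pd.to_datetime(["2020-08-19", "2020-08-20", "2020-08-21"]),
    "Close": [285.61, 294.47, 284.40],
})

# Differencing a datetime column yields timedeltas, not numbers.
diff = df.diff(1).dropna()
is_timedelta = pd.api.types.is_timedelta64_dtype(diff["Date"])

# Comparing the whole frame with 0 includes the timedelta column, which raises.
try:
    diff > 0
    comparison_failed = False
except TypeError:
    comparison_failed = True

# The numeric column on its own compares without any problem.
close_mask = (diff["Close"] > 0).tolist()

print(is_timedelta, comparison_failed, close_mask)
```

The comparison only breaks once the Date differences are part of the frame being compared, so any approach that keeps Date out of diff should avoid the error.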

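I think the simplest fix is to set Date as the index of the original dataframe, so that it is no longer a column that gets differenced, for example: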

...
data = cls.create_stock_df(stock, start_date, end_date)
time_window = 14
data = data.set_index('Date')
...
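Putting it together, the following is a rough sketch of the same calculation with Date as the index. It is written as a standalone function rather than the original classmethods, and it runs on synthetic random-walk prices instead of a Yahoo download. The function signature, the synthetic data, and the variable names are assumptions made for illustration.

```python
import numpy as np
import pandas as pd


def compute_rsi(data, time_window=14):
    """RSI per column, assuming `data` holds only numeric columns."""
    diff = data.diff(1).dropna()

    # Same masking as in the question; it works because no timedelta column is left.
    up_chg = 0 * diff
    down_chg = 0 * diff
    up_chg[diff > 0] = diff[diff > 0]
    down_chg[diff < 0] = diff[diff < 0]

    # com=time_window - 1 corresponds to a decay of alpha = 1 / time_window.
    up_chg_avg = up_chg.ewm(com=time_window - 1, min_periods=time_window).mean()
    down_chg_avg = down_chg.ewm(com=time_window - 1, min_periods=time_window).mean()

    rs = abs(up_chg_avg / down_chg_avg)
    return 100 - 100 / (1 + rs)


# Synthetic stand-in for the downloaded prices (hypothetical values).
rng = np.random.default_rng(0)
dates = pd.date_range("2020-08-03", periods=40, freq="B")
stock_df = pd.DataFrame({"Date": dates, "Close": 290 + rng.normal(0, 3, 40).cumsum()})

# The fix from the answer: move Date into the index before differencing.
data = stock_df.set_index("Date")
rsi = compute_rsi(data)
print(rsi.tail())
```

If the integer index has to be kept, restricting the frame to numeric columns before calling diff, for instance with data.select_dtypes(include="number"), would probably work as well.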