I am trying to compute the RSI of a stock, inspired by this tutorial:
I refactored the code into:
# assumed import: the question does not show where `web` comes from
import pandas_datareader.data as web


class GetStockRSI:
    @classmethod
    def create_stock_df(cls, stock, start_date, end_date):
        df = web.get_data_yahoo(stock, start=start_date, end=end_date)
        # use a numerical integer index instead of the date
        stock_df = df.reset_index()
        return stock_df

    @classmethod
    def compute_rsi(cls, stock, start_date, end_date):
        data = cls.create_stock_df(stock, start_date, end_date)
        time_window = 14
        diff = data.diff(1).dropna()  # difference between consecutive rows (one day)
        # this preserves the dimensions of the diff values
        up_chg = 0 * diff
        down_chg = 0 * diff
        # up change is equal to the positive difference, otherwise equal to zero
        up_chg[diff > 0] = diff[diff > 0]
        # down change is equal to the negative difference, otherwise equal to zero
        down_chg[diff < 0] = diff[diff < 0]
        # values are related to exponential decay
        # we set com=time_window-1 so we get decay alpha=1/time_window
        up_chg_avg = up_chg.ewm(com=time_window - 1, min_periods=time_window).mean()
        down_chg_avg = down_chg.ewm(com=time_window - 1, min_periods=time_window).mean()
        rs = abs(up_chg_avg / down_chg_avg)
        rsi = 100 - 100 / (1 + rs)
        return rsi

    @classmethod
    def main(cls, stock, start_date, end_date):
        rsi = cls.compute_rsi(stock=stock, start_date=start_date, end_date=end_date)
        print(rsi)
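Computing the differences raises the following error:
TypeError: Invalid comparison between dtype=timedelta64[ns] and int
which points to this line:
up_chg[diff > 0] = diff[diff > 0]
I can't see where the error comes from.
Sample of the dataframe: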
         Date        Open        High         Low       Close   Adj Close  Volume
0  2020-08-19  287.799988  291.980011  281.200012  285.609985  285.609985  459900
1  2020-08-20  284.440002  295.320007  283.390015  294.470001  294.470001  526900
2  2020-08-21  294.000000  295.790009  281.339996  284.399994  284.399994  580100
3  2020-08-24  289.500000  291.660004  275.589996  279.239990  279.239990  541100
4  2020-08-25  278.950012  295.904999  276.790009  292.790009  292.790009  852100
5  2020-08-26  295.450012  304.279999  291.329987  296.929993  296.929993  634700
6  2020-08-27  299.089996  299.089996  286.000000  291.049988  291.049988  466900
7  2020-08-28  294.000000  301.760010  290.510010  291.970001  291.970001  365200
8  2020-08-31  293.470001  298.355011  292.640015  294.630005  294.630005  420000
9  2020-09-01  296.750000  305.984985  294.651001  299.709991  299.709991  819300
10 2020-09-02  301.779999  301.779999  280.720001  296.959991  296.959991  779600
Answer 0 (score: 0)
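This is caused by the Date column of the diff dataframe. After reset_index(), Date is an ordinary datetime column, so data.diff(1) turns it into a timedelta column, which cannot be compared with an integer.
diff > 0 fails because diff['Date'] > 0 fails.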
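To make the cause visible without downloading anything, here is a minimal sketch that reproduces the error. The three-row frame is a hypothetical stand-in built from the first sample rows of the question, not the real Yahoo download.

```python
import pandas as pd

# Hypothetical stand-in for the frame returned by create_stock_df():
# Date is a regular column because reset_index() was called.
data = pd.DataFrame({
    "Date": pd.to_datetime(["2020-08-19", "2020-08-20", "2020-08-21"]),
    "Close": [285.609985, 294.470001, 284.399994],
})

diff = data.diff(1).dropna()
# The Date column is now a timedelta column, Close is still float
print(diff.dtypes)

raised = False
try:
    diff > 0  # comparing the timedelta column with an int is not allowed
except TypeError as exc:
    raised = True
    print(type(exc).__name__, exc)
```

Comparing only the numeric columns, for example diff[['Close']] > 0, should work, which is why moving Date out of the columns resolves the problem.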
I think the simplest way is to set Date as the index of the original dataframe, for example:
...
data = cls.create_stock_df(stock, start_date, end_date)
time_window = 14
data = data.set_index('Date')
...
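For completeness, here is a rough, self-contained sketch of the same calculation with Date moved into the index. The standalone function compute_rsi and the hard-coded frame are assumptions for illustration: the Close values are copied from the sample in the question, clip() is swapped in for the masked assignments (it should give the same up/down split), and time_window=5 is used only because eleven rows are too few for the 14-day window.

```python
import pandas as pd


def compute_rsi(data: pd.DataFrame, time_window: int = 14) -> pd.DataFrame:
    """Hypothetical helper: RSI per column, using the ewm-based approach from the question."""
    diff = data.diff(1).dropna()
    # positive differences, everything else set to zero
    up_chg = diff.clip(lower=0)
    # negative differences, everything else set to zero
    down_chg = diff.clip(upper=0)
    # com=time_window-1 gives decay alpha=1/time_window
    up_chg_avg = up_chg.ewm(com=time_window - 1, min_periods=time_window).mean()
    down_chg_avg = down_chg.ewm(com=time_window - 1, min_periods=time_window).mean()
    rs = (up_chg_avg / down_chg_avg).abs()
    return 100 - 100 / (1 + rs)


# Stand-in for create_stock_df(): Close prices copied from the question's sample
data = pd.DataFrame({
    "Date": pd.to_datetime([
        "2020-08-19", "2020-08-20", "2020-08-21", "2020-08-24",
        "2020-08-25", "2020-08-26", "2020-08-27", "2020-08-28",
        "2020-08-31", "2020-09-01", "2020-09-02",
    ]),
    "Close": [
        285.609985, 294.470001, 284.399994, 279.239990,
        292.790009, 296.929993, 291.049988, 291.970001,
        294.630005, 299.709991, 296.959991,
    ],
})

# The fix from the answer: Date becomes the index, so diff() only sees numeric columns
data = data.set_index("Date")

rsi = compute_rsi(data, time_window=5)  # short window only for this small sample
print(rsi)
```

The first rows of the result are NaN because min_periods requires time_window observations before a value is produced.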