I have a Pandas DataFrame like the following:
ticket date close
0 AAA 2018-01-12 176.16
1 AAA 2018-01-13 176.49
3 AAA 2018-01-14 176.00
4 BBB 2018-01-12 78.19
5 BBB 2018-01-13 79.90
6 BBB 2018-01-14 78.10
I have a function:
def rsi(dataframe, period, column = 'close'):
    # Price change from the previous row
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0      # keep only gains
    down[down > 0] = 0  # keep only losses (negative values)
    # Exponential smoothing of gains and losses
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] = rsi
    return dataframe
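For reference, the smoothing in `rsi` follows from how pandas parameterizes `ewm`: `com` sets the smoothing factor to α = 1/(1 + com), so `com=period - 1` gives α = 1/period (Wilder's smoothing), and `adjust=False` selects the recursive form. Writing Δ_t for `delta`, U_t = max(Δ_t, 0) for the gains and D_t = |min(Δ_t, 0)| for the losses (taking `.abs()` after the mean gives the same result because every loss is ≤ 0):

$$
\alpha = \frac{1}{1 + \mathrm{com}} = \frac{1}{\mathrm{period}}, \qquad
\bar U_t = (1-\alpha)\,\bar U_{t-1} + \alpha\,U_t, \qquad
\bar D_t = (1-\alpha)\,\bar D_{t-1} + \alpha\,D_t,
$$

$$
\mathrm{RSI}_t = 100 - \frac{100}{1 + \bar U_t / \bar D_t}.
$$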
What I need is to apply this function to each group of the DataFrame grouped by 'ticket'. I tried the following, but it didn't work. Any suggestions?
print(dataframe.groupby('ticket').apply(rsi, 2))
I get the error:
cannot reindex from a duplicate axis
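The error can be reproduced without downloading anything. A minimal sketch, assuming a small hypothetical frame built from the sample values above: assigning a Series whose index contains duplicate labels to a column of a frame with a different index forces pandas to align (reindex) it, and that alignment is ambiguous, so pandas raises `ValueError`. The exact wording depends on the pandas version; older releases say "cannot reindex from a duplicate axis", newer ones phrase it differently but still mention duplicates.

```python
import pandas as pd

# Hypothetical stand-in for one group as groupby passes it: unique index 0..2.
group = pd.DataFrame({'close': [176.16, 176.49, 176.00]})

# Hypothetical stand-in for a column of the *combined* frame, where each
# ticker kept its own 0..2 index, so the labels 0, 1, 2 each appear twice.
combined_close = pd.Series([176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
                           index=[0, 1, 2, 0, 1, 2])

try:
    # pandas must align combined_close to group.index before assigning;
    # with duplicate labels that alignment fails.
    group['rsi'] = combined_close.diff()
    error = None
except ValueError as exc:
    error = str(exc)

print(error)
```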
# -*- coding: utf-8 -*-
import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column = 'close'):
    delta = all_prices[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] = rsi
    return dataframe

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')
print(all_prices.groupby('ticket').apply(rsi, 2))
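The duplicate labels come from how `get_historical_prices` builds `united_dataframe`: each ticker's frame starts its own index at 0, and appending or concatenating keeps those labels by default. A sketch, assuming the sample values from the table above in place of the IEX response (that endpoint may no longer be available); it uses `pd.concat` because `DataFrame.append` was removed in pandas 2.0. Passing `ignore_index=True` renumbers the rows.

```python
import pandas as pd

# Per-ticker frames standing in for the IEX 'chart' data
# (sample values from the question; dates kept as plain strings).
frames = []
for symbol, closes in {'AAA': [176.16, 176.49, 176.00],
                       'BBB': [78.19, 79.90, 78.10]}.items():
    ticket_dataframe = pd.DataFrame({
        'date': ['2018-01-12', '2018-01-13', '2018-01-14'],
        'close': closes,
    })
    ticket_dataframe.insert(0, 'ticket', symbol)
    frames.append(ticket_dataframe)

# Default concat keeps each frame's own 0..2 labels -> duplicated index.
kept = pd.concat(frames)
print(kept.index.tolist())        # [0, 1, 2, 0, 1, 2]

# ignore_index=True gives a fresh, unique 0..5 index instead.
renumbered = pd.concat(frames, ignore_index=True)
print(renumbered.index.tolist())  # [0, 1, 2, 3, 4, 5]
```

A unique index only avoids the reindex error; the function still has to read its own `dataframe` argument, otherwise every group would get values computed across all tickers.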
Answer 0 (score: 1)
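There is a problem in your source code. This line
delta = all_prices[column].diff()
should be
delta = dataframe[column].diff()
Inside rsi, all_prices is the whole combined frame, whose index repeats 0, 1, 2, … once per ticker, so assigning a result computed from it to the group's rsi column triggers the reindex error. With that fixed, the code runs without any problem. Reassigning the result adds the rsi column to all_prices, i.e.
all_prices = all_prices.groupby('ticket').apply(rsi, 2)
So the final code and result look like this: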
# -*- coding: utf-8 -*-

import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column = 'close'):
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] = rsi
    return dataframe

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')

all_prices = all_prices.groupby('ticket').apply(rsi, 2)
print(all_prices.head())

ticket date close rsi
0 AAPL 2018-01-12 177.09 NaN
1 AAPL 2018-01-16 176.19 0.000000
2 AAPL 2018-01-17 179.10 76.377953
3 AAPL 2018-01-18 179.26 78.208232
4 AAPL 2018-01-19 178.46 44.065484
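An alternative sketch, not the answerer's code: compute RSI on a single close-price Series and attach it with `groupby(...)['close'].transform`, which returns a result aligned to the original index, so nothing is assigned inside the per-group function. It assumes the sample table from the question instead of the IEX data, and the helper name `rsi_series` is hypothetical. For period 2 the first row of each ticker is NaN and the second is 100 here, because the first change in both tickers is a gain with no prior loss.

```python
import pandas as pd

def rsi_series(close, period=14):
    """Hypothetical helper: RSI of one close-price Series,
    using the same smoothing as the question's rsi()."""
    delta = close.diff()
    up = delta.clip(lower=0)    # gains; losses become 0 (NaN stays NaN)
    down = delta.clip(upper=0)  # losses as negative values; gains become 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    return 100 - 100 / (1 + rolling_up / rolling_down)

# Sample data from the question, with its non-contiguous index.
df = pd.DataFrame({
    'ticket': ['AAA'] * 3 + ['BBB'] * 3,
    'date': ['2018-01-12', '2018-01-13', '2018-01-14'] * 2,
    'close': [176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
}, index=[0, 1, 3, 4, 5, 6])

# transform calls rsi_series once per ticker and aligns the result to df.index
df['rsi'] = df.groupby('ticket')['close'].transform(lambda s: rsi_series(s, 2))
print(df)
```

Because each call only ever sees one ticker's closes, `diff()` and `ewm()` never run across the boundary between tickers.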
Answer 1 (score: 0)
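The problem here is related to these lines:
dataframe['rsi'] = rsi
return dataframe
rsi does not have the same index as dataframe, and it also has a different length.
I changed the lines above to
return rsi
and the code runs without problems.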
So the final code and result look like this:
# -*- coding: utf-8 -*-

import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column = 'close'):
    delta = all_prices[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)

    return rsi

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')

print(all_prices.groupby('ticket').apply(rsi, 2))

close 0 1 2 3 4 5 6 \
ticket
AAPL NaN 0.0 76.377953 78.208232 44.065484 16.991057 19.694656
FB NaN 0.0 76.377953 78.208232 44.065484 16.991057 19.694656
TSLA NaN 0.0 76.377953 78.208232 44.065484 16.991057 19.694656
close 7 8 9 ... 11 12 \
ticket ...
AAPL 3.521704 1.252711 15.2917 ... 76.523444 48.103572
FB 3.521704 1.252711 15.2917 ... 76.523444 48.103572
TSLA 3.521704 1.252711 15.2917 ... 76.523444 48.103572
close 13 14 15 16 17 18 \
ticket
AAPL 80.777497 46.146025 23.884925 8.34273 16.897125 75.919398
FB 80.777497 46.146025 23.884925 8.34273 16.897125 75.919398
TSLA 80.777497 46.146025 23.884925 8.34273 16.897125 75.919398
close 19 20
ticket
AAPL 15.705838 12.501725
FB 15.705838 12.501725
TSLA 15.705838 12.501725
[3 rows x 61 columns]
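The three rows in that output are identical because this version of `rsi` still reads the global `all_prices` rather than its `dataframe` argument, so every group returns the same Series with the same index, and pandas lays those results out with one row per group and the index labels as columns. A minimal sketch of that pitfall, assuming the sample values from the question and two hypothetical helpers, `last_close_global` and `last_close_local`:

```python
import pandas as pd

# Sample data from the question (dates omitted for brevity).
df = pd.DataFrame({
    'ticket': ['AAA'] * 3 + ['BBB'] * 3,
    'close': [176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
})

def last_close_global(close):
    # Hypothetical helper showing the bug pattern: it ignores its argument
    # and reads the global frame, so every group gets the same answer.
    return df['close'].iloc[-1]

def last_close_local(close):
    # Hypothetical helper: uses only the Series groupby passes for this group.
    return close.iloc[-1]

same = df.groupby('ticket')['close'].apply(last_close_global).tolist()
per_group = df.groupby('ticket')['close'].apply(last_close_local).tolist()
print(same)       # [78.1, 78.1]
print(per_group)  # [176.0, 78.1]
```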