Correctly applying a custom function to a grouped pandas DataFrame

Time: 2018-02-12 09:42:32

Tags: python function pandas dataframe

I have a pandas DataFrame like the following:

  ticket        date   close
0    AAA  2018-01-12  176.16
1    AAA  2018-01-13  176.49
3    AAA  2018-01-14  176.00
4    BBB  2018-01-12  78.19
5    BBB  2018-01-13  79.90
6    BBB  2018-01-14  78.10

I have a function:

def rsi(dataframe, period, column = 'close'):
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] =  rsi
    return dataframe

What I need is to apply this function to each group produced by groupby('ticket'). I tried the following, but it did not work. Any advice would be appreciated.

print(dataframe.groupby('ticket').apply(rsi, 2))

I get the error:

cannot reindex from a duplicate axis

The complete source code is:

# -*- coding: utf-8 -*-

import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column = 'close'):
    delta = all_prices[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] =  rsi
    return dataframe

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')

print(all_prices.groupby('ticket').apply(rsi, 2))

2 answers:

Answer 0: (score: 1)

There is a bug in the source code. This line

delta = all_prices[column].diff()

should be

delta = dataframe[column].diff() 

With that fix, the code runs without problems. Reassigning the result adds the rsi column to all_prices, i.e.

all_prices = all_prices.groupby('ticket').apply(rsi, 2)

So the final code and result look like this:

In [20]: # -*- coding: utf-8 -*-
    ...: 
    ...: import json
    ...: import pandas
    ...: import requests
    ...: import datetime
    ...: 
    ...: def get_historical_prices(tickets, range):
    ...:     request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    ...:     json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    ...:     united_dataframe = pandas.DataFrame()
    ...:     for symbol in json:
    ...:         ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
    ...:         ticket_dataframe.insert(0, 'ticket', symbol)
    ...:         united_dataframe = united_dataframe.append(ticket_dataframe)
    ...:     return united_dataframe[['ticket', 'date', 'close']]
    ...: 
    ...: def rsi(dataframe, period, column = 'close'):
    ...:     delta = dataframe[column].diff()
    ...:     up, down = delta.copy(), delta.copy()
    ...:     up[up < 0] = 0
    ...:     down[down > 0] = 0
    ...:     rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    ...:     rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    ...:     rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    ...:     dataframe['rsi'] = rsi
    ...:     return dataframe
    ...: 
    ...: # Get the data
    ...: tickets = ['AAPL', 'FB', 'TSLA']
    ...: all_prices = get_historical_prices(tickets, '1m')
    ...: 
    ...: all_prices = all_prices.groupby('ticket').apply(rsi, 2)
    ...: print(all_prices.head())
    ...: 
    ...: 
  ticket        date   close        rsi
0   AAPL  2018-01-12  177.09        NaN
1   AAPL  2018-01-16  176.19   0.000000
2   AAPL  2018-01-17  179.10  76.377953
3   AAPL  2018-01-18  179.26  78.208232
4   AAPL  2018-01-19  178.46  44.065484
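
The same per-group calculation can also be tried offline on the sample data from the question. The following is only a sketch under a few assumptions: the names `rsi_series` and `prices` are made up for illustration, `clip` replaces the masked assignments, and `groupby(...).transform` is used instead of `apply` so that the function never modifies the group it receives. The formula itself is the one from the question.

```python
import pandas as pd

def rsi_series(close, period):
    """RSI of a single price Series (hypothetical helper, not from the post)."""
    delta = close.diff()
    up = delta.clip(lower=0)      # keep gains, zero out losses
    down = delta.clip(upper=0)    # keep losses, zero out gains
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    return 100 - 100 / (1 + rolling_up / rolling_down)

# Sample rows copied from the question, with a fresh 0..5 index.
prices = pd.DataFrame({
    'ticket': ['AAA', 'AAA', 'AAA', 'BBB', 'BBB', 'BBB'],
    'date': ['2018-01-12', '2018-01-13', '2018-01-14'] * 2,
    'close': [176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
})

# transform runs the function once per ticket and returns a Series
# aligned with the original rows, so it can be assigned directly.
prices['rsi'] = prices.groupby('ticket')['close'].transform(
    lambda s: rsi_series(s, period=2)
)
print(prices)
```

The first row of each ticket is NaN because `diff()` has no previous price to compare against within that group.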

Answer 1: (score: 0)

The problem here is with these lines:

dataframe['rsi'] =  rsi
return dataframe

The problem is that rsi does not have the same index as dataframe; moreover, rsi has a different length.

I changed the lines above to

return rsi

and the code runs without problems.

So the final code and result look like this:

In [12]: # -*- coding: utf-8 -*-
    ...: 
    ...: import json
    ...: import pandas
    ...: import requests
    ...: import datetime
    ...: 
    ...: def get_historical_prices(tickets, range):
    ...:     request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    ...:     json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    ...:     united_dataframe = pandas.DataFrame()
    ...:     for symbol in json:
    ...:         ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
    ...:         ticket_dataframe.insert(0, 'ticket', symbol)
    ...:         united_dataframe = united_dataframe.append(ticket_dataframe)
    ...:     return united_dataframe[['ticket', 'date', 'close']]
    ...: 
    ...: def rsi(dataframe, period, column = 'close'):
    ...:     delta = all_prices[column].diff()
    ...:     up, down = delta.copy(), delta.copy()
    ...:     up[up < 0] = 0
    ...:     down[down > 0] = 0
    ...:     rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    ...:     rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    ...:     rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    ...:     
    ...:     return rsi
    ...: 
    ...: # Get the data
    ...: tickets = ['AAPL', 'FB', 'TSLA']
    ...: all_prices = get_historical_prices(tickets, '1m')
    ...: 
    ...: print(all_prices.groupby('ticket').apply(rsi, 2))
    ...: 
    ...: 
close   0    1          2          3          4          5          6   \
ticket                                                                   
AAPL   NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   
FB     NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   
TSLA   NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   

close         7         8        9     ...             11         12  \
ticket                                 ...                             
AAPL    3.521704  1.252711  15.2917    ...      76.523444  48.103572   
FB      3.521704  1.252711  15.2917    ...      76.523444  48.103572   
TSLA    3.521704  1.252711  15.2917    ...      76.523444  48.103572   

close          13         14         15       16         17         18  \
ticket                                                                   
AAPL    80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   
FB      80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   
TSLA    80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   

close          19         20  
ticket                        
AAPL    15.705838  12.501725  
FB      15.705838  12.501725  
TSLA    15.705838  12.501725  

[3 rows x 61 columns]
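
The index mismatch described in this answer can be reproduced without the network call. The sketch below assumes two small per-ticker frames built from the question's sample values (`aaa`, `bbb`, `stacked`, `group` and `renumbered` are illustrative names). Stacking them keeps each frame's own 0..n-1 labels, so a Series computed on the whole frame cannot be aligned to a single group.

```python
import pandas as pd

# Each per-ticker frame starts with its own 0..n-1 index.
aaa = pd.DataFrame({'ticket': 'AAA', 'close': [176.16, 176.49, 176.00]})
bbb = pd.DataFrame({'ticket': 'BBB', 'close': [78.19, 79.90, 78.10]})

# Stacking without renumbering keeps the labels 0, 1, 2, 0, 1, 2.
stacked = pd.concat([aaa, bbb])
print(stacked.index.is_unique)

# Assigning a Series computed on the whole frame to one group forces
# pandas to reindex on duplicated labels, which raises ValueError.
group = stacked[stacked['ticket'] == 'AAA'].copy()
try:
    group['rsi'] = stacked['close'].diff()
    raised = False
except ValueError as err:
    raised = True
    print(err)

# ignore_index=True gives every row a unique label.
renumbered = pd.concat([aaa, bbb], ignore_index=True)
print(renumbered.index.is_unique)
```

Giving the combined frame a unique index does not replace the fix of computing on the group that is passed in; it only removes the duplicate labels that trigger this particular error.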