Pandas idxmax()获取与行

时间:2016-07-30 07:22:20

标签: python pandas max argmax

假设我有以下DataFrame Q_df

        (0, 0)  (0, 1)  (0, 2)  (1, 0)  (1, 1)  (1, 2)  (2, 0)  (2, 1)  (2, 2)
(0, 0)   0.000    0.00     0.0    0.64   0.000     0.0   0.512   0.000     0.0
(0, 1)   0.000    0.00     0.8    0.00   0.512     0.0   0.000   0.512     0.0
(0, 2)   0.000    0.64     0.0    0.00   0.000     0.8   0.000   0.000     1.0
(1, 0)   0.512    0.00     0.0    0.00   0.000     0.8   0.512   0.000     0.0
(1, 1)   0.000    0.64     0.0    0.00   0.000     0.0   0.000   0.512     0.0
(1, 2)   0.000    0.00     0.8    0.64   0.000     0.0   0.000   0.000     1.0
(2, 0)   0.512    0.00     0.0    0.64   0.000     0.0   0.000   0.512     0.0
(2, 1)   0.000    0.64     0.0    0.00   0.512     0.0   0.512   0.000     0.0
(2, 2)   0.000    0.00     0.8    0.00   0.000     0.8   0.000   0.000     0.0

使用以下代码生成:

import numpy as np
import pandas as pd

states = list(itertools.product(range(3), repeat=2))

Q = np.array([[0.000,0.000,0.000,0.640,0.000,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.512,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.000,0.800,0.000,0.000,1.000],
[0.512,0.000,0.000,0.000,0.000,0.800,0.512,0.000,0.000],
[0.000,0.640,0.000,0.000,0.000,0.000,0.000,0.512,0.000],
[0.000,0.000,0.800,0.640,0.000,0.000,0.000,0.000,1.000],
[0.512,0.000,0.000,0.640,0.000,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.512,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.000,0.800,0.000,0.000,0.000]])

Q_df = pd.DataFrame(index=states, columns=states, data=Q)

对于Q的每一行,我想得到与行中最大值对应的列名。如果我试试

policy = Q_df.idxmax()

然后生成的系列看起来像这样:

(0, 0)    (1, 0)
(0, 1)    (0, 2)
(0, 2)    (0, 1)
(1, 0)    (0, 0)
(1, 1)    (0, 1)
(1, 2)    (0, 2)
(2, 0)    (0, 0)
(2, 1)    (0, 1)
(2, 2)    (0, 2)

第一行看起来没问题:第一行的最大元素是0.64,出现在(1,0)列中。第二个也是如此。但是,对于第三行,最大元素为0.8并出现在(1,2)列中,因此我希望policy中的相应值为(1,2),而不是{{} 1}}。

知道这里出了什么问题吗?

1 个答案:

答案 0 :(得分:1)

IIUC,您可以在idxmax中使用axis=1

policy = Q_df.idxmax(axis=1)

(0, 0)    (1, 0)
(0, 1)    (0, 2)
(0, 2)    (2, 2)
(1, 0)    (1, 2)
(1, 1)    (0, 1)
(1, 2)    (2, 2)
(2, 0)    (1, 0)
(2, 1)    (0, 1)
(2, 2)    (0, 2)
dtype: object