Suppose I have the following DataFrame Q_df:
(0, 0) (0, 1) (0, 2) (1, 0) (1, 1) (1, 2) (2, 0) (2, 1) (2, 2)
(0, 0) 0.000 0.00 0.0 0.64 0.000 0.0 0.512 0.000 0.0
(0, 1) 0.000 0.00 0.8 0.00 0.512 0.0 0.000 0.512 0.0
(0, 2) 0.000 0.64 0.0 0.00 0.000 0.8 0.000 0.000 1.0
(1, 0) 0.512 0.00 0.0 0.00 0.000 0.8 0.512 0.000 0.0
(1, 1) 0.000 0.64 0.0 0.00 0.000 0.0 0.000 0.512 0.0
(1, 2) 0.000 0.00 0.8 0.64 0.000 0.0 0.000 0.000 1.0
(2, 0) 0.512 0.00 0.0 0.64 0.000 0.0 0.000 0.512 0.0
(2, 1) 0.000 0.64 0.0 0.00 0.512 0.0 0.512 0.000 0.0
(2, 2) 0.000 0.00 0.8 0.00 0.000 0.8 0.000 0.000 0.0
It was generated with the following code:
import itertools
import numpy as np
import pandas as pd
# all 9 states (i, j) with i, j in {0, 1, 2}, used as both row and column labels
states = list(itertools.product(range(3), repeat=2))
Q = np.array([[0.000,0.000,0.000,0.640,0.000,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.512,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.000,0.800,0.000,0.000,1.000],
[0.512,0.000,0.000,0.000,0.000,0.800,0.512,0.000,0.000],
[0.000,0.640,0.000,0.000,0.000,0.000,0.000,0.512,0.000],
[0.000,0.000,0.800,0.640,0.000,0.000,0.000,0.000,1.000],
[0.512,0.000,0.000,0.640,0.000,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.512,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.000,0.800,0.000,0.000,0.000]])
Q_df = pd.DataFrame(index=states, columns=states, data=Q)
For each row of Q, I want the column name at which that row's maximum occurs. If I try
policy = Q_df.idxmax()
then the resulting Series looks like this:
(0, 0) (1, 0)
(0, 1) (0, 2)
(0, 2) (0, 1)
(1, 0) (0, 0)
(1, 1) (0, 1)
(1, 2) (0, 2)
(2, 0) (0, 0)
(2, 1) (0, 1)
(2, 2) (0, 2)
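The first row looks fine: its largest element is 0.64, in column (1, 0). The second row is also correct. However, the largest element of the third row, (0, 2), is 1.0, in column (2, 2), so I expected the corresponding value in policy to be (2, 2), not (0, 1).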
Any idea what's going wrong here?
Answer 0 (score: 1)
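IIUC, the problem is the axis. With the default axis=0, idxmax works down each column and returns the row label where that column reaches its maximum, so your result is really indexed by column; it only looks row-indexed because your row and column labels are identical. Pass axis=1 to get, for each row, the column label of its maximum: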
policy = Q_df.idxmax(axis=1)
(0, 0) (1, 0)
(0, 1) (0, 2)
(0, 2) (2, 2)
(1, 0) (1, 2)
(1, 1) (0, 1)
(1, 2) (2, 2)
(2, 0) (1, 0)
(2, 1) (0, 1)
(2, 2) (0, 2)
dtype: object
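A minimal sketch, not part of the original answer, that rebuilds Q_df from the question and contrasts the two axes side by side. The names per_column, per_row and per_row_np are illustrative only. The last step is an alternative I am adding as an assumption: NumPy's argmax along axis=1, mapped back to the column labels. Like idxmax, argmax picks the first position on ties, which matters for row (2, 2), where 0.8 appears in both (0, 2) and (1, 2).

```python
import itertools

import numpy as np
import pandas as pd

# Rebuild Q_df as in the question: labels are the 9 states (i, j).
states = list(itertools.product(range(3), repeat=2))
Q = np.array([[0.000, 0.000, 0.000, 0.640, 0.000, 0.000, 0.512, 0.000, 0.000],
              [0.000, 0.000, 0.800, 0.000, 0.512, 0.000, 0.000, 0.512, 0.000],
              [0.000, 0.640, 0.000, 0.000, 0.000, 0.800, 0.000, 0.000, 1.000],
              [0.512, 0.000, 0.000, 0.000, 0.000, 0.800, 0.512, 0.000, 0.000],
              [0.000, 0.640, 0.000, 0.000, 0.000, 0.000, 0.000, 0.512, 0.000],
              [0.000, 0.000, 0.800, 0.640, 0.000, 0.000, 0.000, 0.000, 1.000],
              [0.512, 0.000, 0.000, 0.640, 0.000, 0.000, 0.000, 0.512, 0.000],
              [0.000, 0.640, 0.000, 0.000, 0.512, 0.000, 0.512, 0.000, 0.000],
              [0.000, 0.000, 0.800, 0.000, 0.000, 0.800, 0.000, 0.000, 0.000]])
Q_df = pd.DataFrame(index=states, columns=states, data=Q)

# Default axis=0: one entry per COLUMN, holding the row label where that
# column reaches its maximum (what the question's call returned).
per_column = Q_df.idxmax()

# axis=1: one entry per ROW, holding the column label where that row
# reaches its maximum (the per-row "policy" the question asks for).
per_row = Q_df.idxmax(axis=1)

# The default call is the row-wise call applied to the transpose.
assert per_column.tolist() == Q_df.T.idxmax(axis=1).tolist()

# Hypothetical alternative: argmax along each row gives column positions,
# which are mapped back to the column labels (first position wins on ties).
cols = list(Q_df.columns)
positions = Q_df.to_numpy().argmax(axis=1)
per_row_np = pd.Series([cols[p] for p in positions], index=Q_df.index)

print(per_row)
```

Either per_row or per_row_np can serve as policy; idxmax(axis=1) is the more readable choice because it returns the labels directly instead of positions you have to map back yourself.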