traj0
Out[52]:
state action reward
0 [1.0, 4.0, 6.0] 3.0 4.0
1 [4.0, 6.0, 11.0] 4.0 5.0
2 [6.0, 7.0, 3.0] 3.0 22.0
3 [3.0, 3.0, 2.0] 1.0 10.0
4 [2.0, 9.0, 5.0] 2.0 2.0
假设我有一个像这样的pandas数据框,其中状态列以3元素numpy数组作为条目。
在这里如何查询状态为np.array([3.0,3.0,2.0])
的行?
我知道traj0.query("state == '[3.0,3.0,2.0]'")
有效。但是我不想对查询中的数组值进行硬编码。
我正在寻找类似的东西
x = np.array([3.0,3.0,2.0])
traj0.query('state ==' + x)
=============
这不是重复的问题,因为我先前的问题pandas query with a column consisting of array entries仅适用于每个数组中只有一个值的情况。在这里,我正在寻找数组是否具有多个值。
答案 0 :(得分:0)
import numpy as np
import pandas as pd
df = pd.DataFrame([[np.array([1.0, 4.0, 6.0]), 3.0, 4.0],
[np.array([4.0, 6.0, 11.0]), 4.0, 5.0],
[np.array([6.0, 7.0, 3.0]), 3.0, 22.0],
[np.array([3.0, 3.0, 2.0]), 1.0, 10.0],
[np.array([2.0, 9.0, 5.0]), 2.0, 2.0]
], columns=['state','action','reward'])
x = str(np.array([3.0, 3.0, 2.0]))
df[df.state.astype(str) == x]
// to use pd.query
df['state_str'] = df.state.astype(str)
df.query("state_str == '{}'".format(x))
输出
state action reward
3 [3.0, 3.0, 2.0] 1.0 10.0
答案 1 :(得分:0)
您可以使用df.loc
和使用numpy.array_equal
的lambda函数来做到这一点:
x = [1., 4., 6.]
traj0.loc[df.state.apply(lambda a: np.array_equal(a, x))]
基本上,这会检查state
列中的每个元素是否与x
等价,并仅返回该列匹配的行。
df = pd.DataFrame(data={'state': [[1., 4., 6.], [4., 5., 6.]],
'value': [5, 6]})
print(df.loc[df.state.apply(lambda a: np.array_equal(a, x))])
state value
0 [1.0, 4.0, 6.0] 5
答案 2 :(得分:0)
最好不要在此处使用pd.DataFrame.query
。您可以执行向量化比较,然后使用布尔索引:
x = [3, 3, 2]
mask = (np.array(df['state'].values.tolist()) == x).all(1)
res = df[mask]
print(res)
state action reward
3 [3.0, 3.0, 2.0] 1.0 10.0
通常,您不应该在Pandas系列中存储列表或数组。这是低效率的,并且消除了直接矢量化操作的可能性。在这里,我们必须显式转换为NumPy数组才能进行简单比较。